New Case Study: See how Anthropic automated 95% of dependency reviews with Socket. Learn More
Socket
Sign inDemoInstall
Socket

@google-cloud/vertexai

Package Overview
Dependencies
Maintainers
2
Versions
25
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

@google-cloud/vertexai - npm Package Compare versions

Comparing version 0.2.1 to 0.3.0

2

.release-please-manifest.json
{
  ".": "0.3.0"
}

@@ -10,3 +10,4 @@ {

"api_shortname": "aiplatform",
"library_type": "GAPIC_MANUAL"
"library_type": "GAPIC_MANUAL",
"client_documentation": "https://cloud.google.com/nodejs/docs/reference/vertexai/latest"
}

@@ -18,3 +18,3 @@ /**

import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
import { Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting, StreamGenerateContentResult, VertexInit } from './types/content';
import { Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting, StreamGenerateContentResult, Tool, VertexInit } from './types/content';
export * from './types';

@@ -65,2 +65,3 @@ /**

getGenerativeModel(modelParams: ModelParams): GenerativeModel;
validateGoogleAuthOptions(project: string, googleAuthOptions?: GoogleAuthOptions): GoogleAuthOptions;
}

@@ -77,2 +78,3 @@ /**

generation_config?: GenerationConfig;
tools?: Tool[];
}

@@ -102,2 +104,3 @@ /**

safety_settings?: SafetySetting[];
tools?: Tool[];
get history(): Content[];

@@ -115,3 +118,3 @@ /**

sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>;
appendHistory(streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>, newContent: Content): Promise<void>;
appendHistory(streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>, newContent: Content[]): Promise<void>;
/**

@@ -133,2 +136,3 @@ * Make an async call to stream send message. Response will be returned in stream.

safety_settings?: SafetySetting[];
tools?: Tool[];
private _vertex_instance;

@@ -145,3 +149,3 @@ private _use_non_stream;

*/
constructor(vertex_instance: VertexAI_Preview, model: string, generation_config?: GenerationConfig, safety_settings?: SafetySetting[]);
constructor(vertex_instance: VertexAI_Preview, model: string, generation_config?: GenerationConfig, safety_settings?: SafetySetting[], tools?: Tool[]);
/**

@@ -148,0 +152,0 @@ * Make a async call to generate content.

@@ -79,18 +79,3 @@ "use strict";

this.googleAuthOptions = googleAuthOptions;
let opts;
if (!googleAuthOptions) {
opts = {
scopes: 'https://www.googleapis.com/auth/cloud-platform',
};
}
else {
if (googleAuthOptions.projectId &&
googleAuthOptions.projectId !== project) {
throw new Error(`inconsistent project ID values. argument project got value ${project} but googleAuthOptions.projectId got value ${googleAuthOptions.projectId}`);
}
opts = googleAuthOptions;
if (!opts.scopes) {
opts.scopes = 'https://www.googleapis.com/auth/cloud-platform';
}
}
const opts = this.validateGoogleAuthOptions(project, googleAuthOptions);
this.project = project;

@@ -126,4 +111,28 @@ this.location = location;

}
return new GenerativeModel(this, modelParams.model, modelParams.generation_config, modelParams.safety_settings);
return new GenerativeModel(this, modelParams.model, modelParams.generation_config, modelParams.safety_settings, modelParams.tools);
}
validateGoogleAuthOptions(project, googleAuthOptions) {
let opts;
const requiredScope = 'https://www.googleapis.com/auth/cloud-platform';
if (!googleAuthOptions) {
opts = {
scopes: requiredScope,
};
return opts;
}
if (googleAuthOptions.projectId &&
googleAuthOptions.projectId !== project) {
throw new Error(`inconsistent project ID values. argument project got value ${project} but googleAuthOptions.projectId got value ${googleAuthOptions.projectId}`);
}
opts = googleAuthOptions;
if (!opts.scopes) {
opts.scopes = requiredScope;
return opts;
}
if ((typeof opts.scopes === 'string' && opts.scopes !== requiredScope) ||
(Array.isArray(opts.scopes) && opts.scopes.indexOf(requiredScope) < 0)) {
throw new errors_1.GoogleAuthError(`input GoogleAuthOptions.scopes ${opts.scopes} doesn't contain required scope ${requiredScope}, please include ${requiredScope} into GoogleAuthOptions.scopes or leave GoogleAuthOptions.scopes undefined`);
}
return opts;
}
}

@@ -152,2 +161,5 @@ exports.VertexAI_Preview = VertexAI_Preview;

this._vertex_instance = request._vertex_instance;
this.generation_config = request.generation_config;
this.safety_settings = request.safety_settings;
this.tools = request.tools;
}

@@ -160,7 +172,8 @@ /**

async sendMessage(request) {
const newContent = formulateNewContent(request);
const newContent = formulateNewContentFromSendMessageRequest(request);
const generateContentrequest = {
contents: this.historyInternal.concat([newContent]),
contents: this.historyInternal.concat(newContent),
safety_settings: this.safety_settings,
generation_config: this.generation_config,
tools: this.tools,
};

@@ -175,3 +188,3 @@ const generateContentResult = await this._model_instance

if (generateContentResponse.candidates.length !== 0) {
this.historyInternal.push(newContent);
this.historyInternal = this.historyInternal.concat(newContent);
const contentFromAssistant = generateContentResponse.candidates[0].content;

@@ -194,3 +207,3 @@ if (!contentFromAssistant.role) {

if (streamGenerateContentResponse.candidates.length !== 0) {
this.historyInternal.push(newContent);
this.historyInternal = this.historyInternal.concat(newContent);
const contentFromAssistant = streamGenerateContentResponse.candidates[0].content;

@@ -213,7 +226,8 @@ if (!contentFromAssistant.role) {

async sendMessageStream(request) {
const newContent = formulateNewContent(request);
const newContent = formulateNewContentFromSendMessageRequest(request);
const generateContentrequest = {
contents: this.historyInternal.concat([newContent]),
contents: this.historyInternal.concat(newContent),
safety_settings: this.safety_settings,
generation_config: this.generation_config,
tools: this.tools,
};

@@ -246,3 +260,3 @@ const streamGenerateContentResultPromise = this._model_instance

*/
constructor(vertex_instance, model, generation_config, safety_settings) {
constructor(vertex_instance, model, generation_config, safety_settings, tools) {
this._use_non_stream = false;

@@ -253,2 +267,3 @@ this._vertex_instance = vertex_instance;

this.safety_settings = safety_settings;
this.tools = tools;
if (model.startsWith('models/')) {

@@ -267,5 +282,5 @@ this.publisherModelEndpoint = `publishers/google/${this.model}`;

async generateContent(request) {
var _a, _b;
var _a, _b, _c;
request = formatContentRequest(request, this.generation_config, this.safety_settings);
validateGcsInput(request.contents);
validateGenerateContentRequest(request);
if (request.generation_config) {

@@ -287,2 +302,3 @@ request.generation_config = validateGenerationConfig(request.generation_config);

safety_settings: (_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings,
tools: (_c = request.tools) !== null && _c !== void 0 ? _c : [],
};

@@ -310,5 +326,5 @@ const response = await (0, util_1.postRequest)({

async generateContentStream(request) {
var _a, _b;
var _a, _b, _c;
request = formatContentRequest(request, this.generation_config, this.safety_settings);
validateGcsInput(request.contents);
validateGenerateContentRequest(request);
if (request.generation_config) {

@@ -321,2 +337,3 @@ request.generation_config = validateGenerationConfig(request.generation_config);

safety_settings: (_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings,
tools: (_c = request.tools) !== null && _c !== void 0 ? _c : [],
};

@@ -366,3 +383,3 @@ const response = await (0, util_1.postRequest)({

startChat(request) {
var _a, _b;
var _a, _b, _c;
const startChatRequest = {

@@ -378,2 +395,3 @@ _vertex_instance: this._vertex_instance,

(_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings;
startChatRequest.tools = (_c = request.tools) !== null && _c !== void 0 ? _c : this.tools;
}

@@ -384,3 +402,3 @@ return new ChatSession(startChatRequest);

exports.GenerativeModel = GenerativeModel;
function formulateNewContent(request) {
function formulateNewContentFromSendMessageRequest(request) {
let newParts = [];

@@ -400,5 +418,38 @@ if (typeof request === 'string') {

}
const newContent = { role: util_1.constants.USER_ROLE, parts: newParts };
return newContent;
return assignRoleToPartsAndValidateSendMessageRequest(newParts);
}
/**
 * Buckets the given parts into a single Content item, assigning the role
 * based on part type: FunctionResponseParts get the 'function' role and every
 * other part type gets the 'user' role. Mixing the two kinds within one
 * message is rejected, as is an empty part list.
 * @ignore
 * @param {Array<Part>} parts Array of parts to pass to the model
 * @return {Content[]} Array of content items
 */
function assignRoleToPartsAndValidateSendMessageRequest(parts) {
    const userParts = [];
    const functionParts = [];
    for (const part of parts) {
        if ('functionResponse' in part) {
            functionParts.push(part);
        }
        else {
            userParts.push(part);
        }
    }
    if (userParts.length > 0 && functionParts.length > 0) {
        throw new errors_1.ClientError('Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.');
    }
    if (userParts.length === 0 && functionParts.length === 0) {
        throw new errors_1.ClientError('No content is provided for sending chat message.');
    }
    if (userParts.length > 0) {
        return [{ role: util_1.constants.USER_ROLE, parts: userParts }];
    }
    return [{ role: util_1.constants.FUNCTION_ROLE, parts: functionParts }];
}
function throwErrorIfNotOK(response) {

@@ -431,2 +482,20 @@ if (response === undefined) {

}
/**
 * Ensures that, when the final turn in `contents` is a function response,
 * the turn immediately before it is a function call; throws ClientError
 * otherwise. Histories whose final turn is not a function response pass
 * through unchecked.
 */
function validateFunctionResponseRequest(contents) {
    const finalTurnFirstPart = contents[contents.length - 1].parts[0];
    if (!('functionResponse' in finalTurnFirstPart)) {
        return;
    }
    const errorMessage = 'Please ensure that function response turn comes immediately after a function call turn.';
    // A function response needs a preceding turn, and that turn's first part
    // must be a function call.
    const precedingTurn = contents.length >= 2 ? contents[contents.length - 2] : undefined;
    if (!precedingTurn || !('functionCall' in precedingTurn.parts[0])) {
        throw new errors_1.ClientError(errorMessage);
    }
}
// Runs all request-level validations for generateContent /
// generateContentStream: GCS URI checks on the contents, then the
// function-response ordering check.
function validateGenerateContentRequest(request) {
    validateGcsInput(request.contents);
    validateFunctionResponseRequest(request.contents);
}
function validateGenerationConfig(generation_config) {

@@ -433,0 +502,0 @@ if ('top_k' in generation_config) {

@@ -162,2 +162,9 @@ "use strict";

}
if (part.functionCall) {
aggregatedResponse.candidates[i].content.parts[0].functionCall =
part.functionCall;
// the empty 'text' key should be removed if functionCall is in the
// response
delete aggregatedResponse.candidates[i].content.parts[0].text;
}
}

@@ -164,0 +171,0 @@ }

@@ -73,2 +73,3 @@ /**

generation_config?: GenerationConfig;
tools?: Tool[];
}

@@ -194,2 +195,4 @@ /**

* 3. file_data
* 4. functionResponse
* 5. functionCall
*/

@@ -203,2 +206,6 @@ export interface BasePart {

* @property {never} - [file_data]. file_data is not expected for TextPart.
* @property {never} - [functionResponse]. functionResponse is not expected for
* TextPart.
* @property {never} - [functionCall]. functionCall is not expected for
* TextPart.
*

@@ -210,2 +217,4 @@ */

file_data?: never;
functionResponse?: never;
functionCall?: never;
}

@@ -215,4 +224,11 @@ /**

* @property {never} - [text]. text is not expected for InlineDataPart.
* @property {GenerativeContentBlob} - inline_data. Only this property is expected for InlineDataPart. {@link GenerativeContentBlob}
* @property {never} - [file_data]. file_data is not expected for InlineDataPart.
* @property {GenerativeContentBlob} - inline_data. Only this property is
* expected for InlineDataPart. {@link GenerativeContentBlob}
* @property {never} - [file_data]. file_data is not expected for
* InlineDataPart.
* @property {never} - [functionResponse]. functionResponse is not expected for
* InlineDataPart.
* @property {never} - [functionCall]. functionCall is not expected for
* InlineDataPart.
*
*/

@@ -223,2 +239,4 @@ export interface InlineDataPart extends BasePart {

file_data?: never;
functionResponse?: never;
functionCall?: never;
}

@@ -237,4 +255,11 @@ /**

* @property {never} - [text]. text is not expected for FileDataPart.
* @property {never} - [inline_data]. inline_data is not expected for FileDataPart.
* @property {FileData} - file_data. Only this property is expected for FileDataPart. {@link FileData}
* @property {never} - [inline_data]. inline_data is not expected for
* FileDataPart.
* @property {FileData} - file_data. Only this property is expected for
* FileDataPart. {@link FileData}
* @property {never} - [functionResponse]. functionResponse is not expected for
* FileDataPart.
* @property {never} - [functionCall]. functionCall is not expected for
* FileDataPart.
*
*/

@@ -245,12 +270,56 @@ export interface FileDataPart extends BasePart {

file_data: FileData;
functionResponse?: never;
functionCall?: never;
}
/**
 * A function response part of a conversation with the model.
 * @property {never} - [text]. text is not expected for FunctionResponsePart.
 * @property {never} - [inline_data]. inline_data is not expected for
 * FunctionResponsePart.
 * @property {never} - [file_data]. file_data is not expected for
 * FunctionResponsePart.
 * @property {FunctionResponse} - functionResponse. Only functionResponse is
 * expected for FunctionResponsePart. {@link FunctionResponse}
 * @property {never} - [functionCall]. functionCall is not expected for
 * FunctionResponsePart.
 *
 */
export interface FunctionResponsePart extends BasePart {
text?: never;
inline_data?: never;
file_data?: never;
functionResponse: FunctionResponse;
functionCall?: never;
}
/**
 * A function call part of a conversation with the model.
 * @property {never} - [text]. text is not expected for FunctionCallPart.
 * @property {never} - [inline_data]. inline_data is not expected for
 * FunctionCallPart.
 * @property {never} - [file_data]. file_data is not expected for
 * FunctionCallPart.
 * @property {never} - [functionResponse]. functionResponse is not expected for
 * FunctionCallPart.
 * @property {FunctionCall} - functionCall. Only functionCall is expected for
 * FunctionCallPart. {@link FunctionCall}
 *
 */
export interface FunctionCallPart extends BasePart {
text?: never;
inline_data?: never;
file_data?: never;
functionResponse?: never;
functionCall: FunctionCall;
}
/**
 * A datatype containing media that is part of a multi-part {@link Content}
 * message. A `Part` is a union type of {@link TextPart}, {@link
 * InlineDataPart}, {@link FileDataPart}, {@link FunctionResponsePart}, and
 * {@link FunctionCallPart}. A `Part` has one of the following mutually
 * exclusive fields:
 * 1. text
 * 2. inline_data
 * 3. file_data
 * 4. functionResponse
 * 5. functionCall
 */
export declare type Part = TextPart | InlineDataPart | FileDataPart | FunctionResponsePart | FunctionCallPart;
/**

@@ -380,2 +449,3 @@ * Raw media bytes sent directly in the request. Text should not be sent as

citationMetadata?: CitationMetadata;
functionCall?: FunctionCall;
}

@@ -402,1 +472,118 @@ /**

}
/**
 * A predicted FunctionCall returned from the model that contains a string
 * representing the FunctionDeclaration.name with the parameters and their
 * values.
 * @property {string} - name The name of the function specified in
 * FunctionDeclaration.name.
 * @property {object} - args The arguments to pass to the function.
 */
export declare interface FunctionCall {
name: string;
args: object;
}
/**
 * The result output of a FunctionCall that contains a string representing
 * the FunctionDeclaration.name and a structured JSON object containing any
 * output from the function call. It is used as context to the model.
 * @property {string} - name The name of the function specified in
 * FunctionDeclaration.name.
 * @property {object} - response The structured JSON output of the function
 * call, passed back to the model as context.
 */
export declare interface FunctionResponse {
name: string;
response: object;
}
/**
 * Structured representation of a function declaration as defined by the
 * [OpenAPI 3.0 specification](https://spec.openapis.org/oas/v3.0.3). Included
 * in this declaration are the function name and parameters. This
 * FunctionDeclaration is a representation of a block of code that can be used
 * as a Tool by the model and executed by the client.
 * @property {string} - name The name of the function to call. Must start with a
 * letter or an underscore. Must be a-z, A-Z, 0-9, or contain underscores and
 * dashes, with a max length of 64.
 * @property {string} - description Description and purpose of the function.
 * Model uses it to decide how and whether to call the function.
 * @property {FunctionDeclarationSchema} - parameters Describes the parameters
 * to this function in JSON Schema Object format. Reflects the Open API 3.03
 * Parameter Object. string Key: the name of the parameter. Parameter names are
 * case sensitive. Schema Value: the Schema defining the type used for the
 * parameter. For function with no parameters, this can be left unset. Example
 * with 1 required and 1 optional parameter:
 *   type: OBJECT
 *   properties:
 *     param1:
 *       type: STRING
 *     param2:
 *       type: INTEGER
 *   required:
 *     - param1
 */
export declare interface FunctionDeclaration {
name: string;
description?: string;
parameters?: FunctionDeclarationSchema;
}
/**
 * A Tool is a piece of code that enables the system to interact with
 * external systems to perform an action, or set of actions, outside of
 * knowledge and scope of the model.
 * @property {FunctionDeclaration[]} - function_declarations One or more
 * function declarations to be passed to the model along with the current user
 * query. Model may decide to call a subset of these functions by populating
 * [FunctionCall][content.part.function_call] in the response. User should
 * provide a [FunctionResponse][content.part.function_response] for each
 * function call in the next turn. Based on the function responses, Model will
 * generate the final response back to the user. Maximum 64 function
 * declarations can be provided.
 */
export declare interface Tool {
function_declarations: FunctionDeclaration[];
}
/**
 * Contains the list of OpenAPI data types
 * as defined by https://swagger.io/docs/specification/data-models/data-types/
 * Member values mirror the OpenAPI type names verbatim (upper-cased).
 * @public
 */
export declare enum FunctionDeclarationSchemaType {
STRING = "STRING",
NUMBER = "NUMBER",
INTEGER = "INTEGER",
BOOLEAN = "BOOLEAN",
ARRAY = "ARRAY",
OBJECT = "OBJECT"
}
/**
 * Schema for parameters passed to [FunctionDeclaration.parameters]
 * @public
 * @property {FunctionDeclarationSchemaType} - type Top-level OpenAPI type of
 * the parameter object.
 * @property {object} - properties Map of parameter name to its schema.
 * @property {string} - [description] Optional description of the schema.
 * @property {string[]} - [required] Names of required parameters.
 */
export interface FunctionDeclarationSchema {
type: FunctionDeclarationSchemaType;
properties: {
[k: string]: FunctionDeclarationSchemaProperty;
};
description?: string;
required?: string[];
}
/**
 * Schema is used to define the format of input/output data.
 * Represents a select subset of an OpenAPI 3.0 schema object.
 * More fields may be added in the future as needed.
 * @public
 * @property {FunctionDeclarationSchemaType} - [type] OpenAPI data type.
 * @property {string} - [format] Optional format refinement of the type.
 * @property {string} - [description] Human-readable description.
 * @property {boolean} - [nullable] Whether the value may be null.
 * @property {FunctionDeclarationSchema} - [items] Element schema when the
 * type is ARRAY.
 * @property {string[]} - [enum] Allowed literal values.
 * @property {object} - [properties] Child schemas when the type is OBJECT.
 * @property {string[]} - [required] Required child property names.
 * @property {unknown} - [example] Example value.
 */
export interface FunctionDeclarationSchemaProperty {
type?: FunctionDeclarationSchemaType;
format?: string;
description?: string;
nullable?: boolean;
items?: FunctionDeclarationSchema;
enum?: string[];
properties?: {
[k: string]: FunctionDeclarationSchema;
};
required?: string[];
example?: unknown;
}

@@ -19,3 +19,3 @@ "use strict";

Object.defineProperty(exports, "__esModule", { value: true });
exports.FinishReason = exports.BlockedReason = exports.HarmProbability = exports.HarmBlockThreshold = exports.HarmCategory = void 0;
exports.FunctionDeclarationSchemaType = exports.FinishReason = exports.BlockedReason = exports.HarmProbability = exports.HarmBlockThreshold = exports.HarmCategory = void 0;
/**

@@ -143,2 +143,16 @@ * @enum {string}

})(FinishReason || (exports.FinishReason = FinishReason = {}));
/**
 * Contains the list of OpenAPI data types
 * as defined by https://swagger.io/docs/specification/data-models/data-types/
 * @public
 */
// Compiler-emitted enum object: assigns each member name its string value and
// publishes the enum on both the local binding and the module exports.
var FunctionDeclarationSchemaType;
(function (FunctionDeclarationSchemaType) {
FunctionDeclarationSchemaType["STRING"] = "STRING";
FunctionDeclarationSchemaType["NUMBER"] = "NUMBER";
FunctionDeclarationSchemaType["INTEGER"] = "INTEGER";
FunctionDeclarationSchemaType["BOOLEAN"] = "BOOLEAN";
FunctionDeclarationSchemaType["ARRAY"] = "ARRAY";
FunctionDeclarationSchemaType["OBJECT"] = "OBJECT";
})(FunctionDeclarationSchemaType || (exports.FunctionDeclarationSchemaType = FunctionDeclarationSchemaType = {}));
//# sourceMappingURL=content.js.map

@@ -21,2 +21,3 @@ /**

export declare const MODEL_ROLE = "model";
export declare const FUNCTION_ROLE = "function";
// User agent reflects the 0.3.0 client library version (x-release-please).
export declare const USER_AGENT = "model-builder/0.3.0 grpc-node/0.3.0";
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.USER_AGENT = exports.MODEL_ROLE = exports.USER_ROLE = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0;
exports.USER_AGENT = exports.FUNCTION_ROLE = exports.MODEL_ROLE = exports.USER_ROLE = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0;
/**

@@ -24,6 +24,7 @@ * @license

exports.MODEL_ROLE = 'model';
exports.FUNCTION_ROLE = 'function';
const USER_AGENT_PRODUCT = 'model-builder';
const CLIENT_LIBRARY_VERSION = '0.2.1'; // x-release-please-version
const CLIENT_LIBRARY_VERSION = '0.3.0'; // x-release-please-version
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`;
exports.USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`;
//# sourceMappingURL=constants.js.map

@@ -20,5 +20,4 @@ "use strict";

// @ts-ignore
const assert = require("assert");
const src_1 = require("../src");
// TODO: this env var isn't getting populated correctly
const types_1 = require("../src/types");
const PROJECT = process.env.GCLOUD_PROJECT;

@@ -34,3 +33,3 @@ const LOCATION = 'us-central1';

file_data: {
file_uri: 'gs://nodejs_vertex_system_test_resources/scones.jpg',
file_uri: 'gs://generativeai-downloads/images/scones.jpg',
mime_type: 'image/jpeg',

@@ -52,4 +51,44 @@ },

};
// Name shared by the tool declaration, the canned function call, and the
// canned function response fixtures below.
const FUNCTION_CALL_NAME = 'get_current_weather';
// Tool fixture: one function declaration, get_current_weather, with a
// required `location` parameter and an optional `unit` enum parameter.
const TOOLS_WITH_FUNCTION_DECLARATION = [
{
function_declarations: [
{
name: FUNCTION_CALL_NAME,
description: 'get weather in a given location',
parameters: {
type: types_1.FunctionDeclarationSchemaType.OBJECT,
properties: {
location: { type: types_1.FunctionDeclarationSchemaType.STRING },
unit: {
type: types_1.FunctionDeclarationSchemaType.STRING,
enum: ['celsius', 'fahrenheit'],
},
},
required: ['location'],
},
},
],
},
];
const WEATHER_FORECAST = 'super nice';
// Canned function-role reply the tests send back after the model issues a
// function call; tests assert the forecast string appears in the final text.
const FUNCTION_RESPONSE_PART = [
{
functionResponse: {
name: FUNCTION_CALL_NAME,
response: {
name: FUNCTION_CALL_NAME,
content: { weather: WEATHER_FORECAST },
},
},
},
];
// Canned model-role function call used to seed conversation history.
const FUNCTION_CALL = [
{ functionCall: { name: FUNCTION_CALL_NAME, args: { location: 'boston' } } },
];
// Initialize Vertex with your Cloud project and location; the project comes
// from the GCLOUD_PROJECT environment variable read above.
const vertex_ai = new src_1.VertexAI({
    project: PROJECT,
    location: LOCATION,
});
const generativeTextModel = vertex_ai.preview.getGenerativeModel({

@@ -76,7 +115,5 @@ model: 'gemini-pro',

});
// TODO (b/316599049): update tests to use jasmine expect syntax:
// expect(...).toBeInstanceOf(...)
describe('generateContentStream', () => {
beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000;
});

@@ -86,6 +123,6 @@ it('should should return a stream and aggregated response when passed text', async () => {

for await (const item of streamingResp.stream) {
assert(item.candidates[0], `sys test failure on generateContentStream, for item ${item}`);
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream, for item ${item}`);
}
const aggregatedResp = await streamingResp.response;
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
});

@@ -97,6 +134,6 @@ it('should not return a invalid unicode', async () => {

for await (const item of streamingResp.stream) {
assert(item.candidates[0], `sys test failure on generateContentStream, for item ${item}`);
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream, for item ${item}`);
for (const candidate of item.candidates) {
for (const part of candidate.content.parts) {
assert(!part.text.includes('\ufffd'), `sys test failure on generateContentStream, for item ${item}`);
expect(part.text).not.toContain('\ufffd', `sys test failure on generateContentStream, for item ${item}`);
}

@@ -106,3 +143,3 @@ }

const aggregatedResp = await streamingResp.response;
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
});

@@ -112,6 +149,6 @@ it('should return a stream and aggregated response when passed multipart base64 content', async () => {

for await (const item of streamingResp.stream) {
assert(item.candidates[0], `sys test failure on generateContentStream, for item ${item}`);
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream, for item ${item}`);
}
const aggregatedResp = await streamingResp.response;
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
});

@@ -131,23 +168,55 @@ it('should throw ClientError when having invalid input', async () => {

await generativeVisionModel.generateContentStream(badRequest).catch(e => {
assert(e instanceof src_1.ClientError, `sys test failure on generateContentStream when having bad request should throw ClientError but actually thrown ${e}`);
assert(e.message === '[VertexAI.ClientError]: got status: 400 Bad Request', `sys test failure on generateContentStream when having bad request got wrong error message: ${e.message}`);
expect(e).toBeInstanceOf(src_1.ClientError);
expect(e.message).toBe('[VertexAI.ClientError]: got status: 400 Bad Request', `sys test failure on generateContentStream when having bad request
got wrong error message: ${e.message}`);
});
});
// TODO: this is returning a 500 on the system test project
// it('should should return a stream and aggregated response when passed
// multipart GCS content',
// async () => {
// const streamingResp = await
// generativeVisionModel.generateContentStream(
// MULTI_PART_GCS_REQUEST);
// for await (const item of streamingResp.stream) {
// assert(item.candidates[0]);
// console.log('stream chunk: ', item);
// }
// const aggregatedResp = await streamingResp.response;
// assert(aggregatedResp.candidates[0]);
// console.log('aggregated response: ', aggregatedResp);
// });
it('should should return a stream and aggregated response when passed multipart GCS content', async () => {
const streamingResp = await generativeVisionModel.generateContentStream(MULTI_PART_GCS_REQUEST);
for await (const item of streamingResp.stream) {
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream, for item ${item}`);
}
const aggregatedResp = await streamingResp.response;
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
});
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => {
var _a;
const request = {
contents: [
{ role: 'user', parts: [{ text: 'What is the weather in Boston?' }] },
{ role: 'model', parts: FUNCTION_CALL },
{ role: 'function', parts: FUNCTION_RESPONSE_PART },
],
tools: TOOLS_WITH_FUNCTION_DECLARATION,
};
const streamingResp = await generativeTextModel.generateContentStream(request);
for await (const item of streamingResp.stream) {
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream, for item ${item}`);
expect((_a = item.candidates[0].content.parts[0].text) === null || _a === void 0 ? void 0 : _a.toLowerCase()).toContain(WEATHER_FORECAST);
}
});
});
// TODO (b/316599049): add tests for generateContent and sendMessage
describe('generateContent', () => {
beforeEach(() => {
// Allow up to 20s for the remote model call.
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
});
it('should return the aggregated response', async () => {
const response = await generativeTextModel.generateContent(TEXT_REQUEST);
const aggregatedResp = await response.response;
// At least one candidate must come back for a valid text request.
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`);
});
});
describe('sendMessage', () => {
beforeEach(() => {
// Allow up to 20s for the remote model call.
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
});
it('should populate history and return a chat response', async () => {
const chat = generativeTextModel.startChat();
const chatInput1 = 'How can I learn more about Node.js?';
const result1 = await chat.sendMessage(chatInput1);
const response1 = await result1.response;
expect(response1.candidates[0]).toBeTruthy(`sys test failure on sendMessage for aggregated response: ${response1}`);
// One user turn plus one model turn should have been appended.
expect(chat.history.length).toBe(2);
});
});
describe('sendMessageStream', () => {

@@ -166,6 +235,6 @@ beforeEach(() => {

for await (const item of result1.stream) {
assert(item.candidates[0], `sys test failure on sendMessageStream, for item ${item}`);
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream, for item ${item}`);
}
const resp = await result1.response;
assert(resp.candidates[0], `sys test failure on sendMessageStream for aggregated response: ${resp}`);
expect(resp.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream for aggregated response: ${resp}`);
expect(chat.history.length).toBe(2);

@@ -178,6 +247,6 @@ });

for await (const item of result1.stream) {
assert(item.candidates[0], `sys test failure on sendMessageStream, for item ${item}`);
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream, for item ${item}`);
}
const resp = await result1.response;
assert(resp.candidates[0], `sys test failure on sendMessageStream for aggregated response: ${resp}`);
expect(resp.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream for aggregated response: ${resp}`);
expect(chat.history.length).toBe(2);

@@ -201,2 +270,22 @@ });

});
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => {
const chat = generativeTextModel.startChat({
tools: TOOLS_WITH_FUNCTION_DECLARATION,
});
const chatInput1 = 'What is the weather in Boston?';
const result1 = await chat.sendMessageStream(chatInput1);
for await (const item of result1.stream) {
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream with function calling, for item ${item}`);
}
const response1 = await result1.response;
expect(JSON.stringify(response1.candidates[0].content.parts[0].functionCall)).toContain(FUNCTION_CALL_NAME);
expect(JSON.stringify(response1.candidates[0].content.parts[0].functionCall)).toContain('location');
// Send a follow up message with a FunctionResponse
const result2 = await chat.sendMessageStream(FUNCTION_RESPONSE_PART);
for await (const item of result2.stream) {
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream with function calling, for item ${item}`);
}
const response2 = await result2.response;
expect(JSON.stringify(response2.candidates[0].content.parts[0].text)).toContain(WEATHER_FORECAST);
});
});

@@ -206,3 +295,3 @@ describe('countTokens', () => {

const countTokensResp = await generativeTextModel.countTokens(TEXT_REQUEST);
assert(countTokensResp.totalTokens, `sys test failure on countTokens, ${countTokensResp}`);
expect(countTokensResp.totalTokens).toBeTruthy(`sys test failure on countTokens, ${countTokensResp}`);
});

@@ -212,3 +301,3 @@ });

beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 25000;
jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000;
});

@@ -218,6 +307,6 @@ it('should should return a stream and aggregated response when passed text', async () => {

for await (const item of streamingResp.stream) {
assert(item.candidates[0], `sys test failure on generateContentStream using models/gemini-pro, for item ${item}`);
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream using models/gemini-pro, for item ${item}`);
}
const aggregatedResp = await streamingResp.response;
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream using models/gemini-pro for aggregated response: ${aggregatedResp}`);
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream using models/gemini-pro for aggregated response: ${aggregatedResp}`);
});

@@ -227,8 +316,8 @@ it('should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => {

for await (const item of streamingResp.stream) {
assert(item.candidates[0], `sys test failure on generateContentStream using models/gemini-pro-vision, for item ${item}`);
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream using models/gemini-pro-vision, for item ${item}`);
}
const aggregatedResp = await streamingResp.response;
assert(aggregatedResp.candidates[0], `sys test failure on generateContentStream using models/gemini-pro-vision for aggregated response: ${aggregatedResp}`);
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream using models/gemini-pro-vision for aggregated response: ${aggregatedResp}`);
});
});
//# sourceMappingURL=end_to_end_sample_test.js.map

@@ -23,1 +23,2 @@ /**

export declare function testGenerator(): AsyncGenerator<GenerateContentResponse>;
export declare function testGeneratorWithEmptyResponse(): AsyncGenerator<GenerateContentResponse>;

@@ -19,3 +19,3 @@ "use strict";

Object.defineProperty(exports, "__esModule", { value: true });
exports.testGenerator = void 0;
exports.testGeneratorWithEmptyResponse = exports.testGenerator = void 0;
/* tslint:disable */

@@ -25,2 +25,3 @@ const index_1 = require("../src/index");

const content_1 = require("../src/types/content");
const errors_1 = require("../src/types/errors");
const util_1 = require("../src/util");

@@ -98,2 +99,34 @@ const PROJECT = 'test_project';

};
const TEST_FUNCTION_CALL_RESPONSE = {
functionCall: {
name: 'get_current_weather',
args: {
location: 'LA',
unit: 'fahrenheit',
},
},
};
const TEST_CANDIDATES_WITH_FUNCTION_CALL = [
{
index: 1,
content: {
role: util_1.constants.MODEL_ROLE,
parts: [TEST_FUNCTION_CALL_RESPONSE],
},
finishReason: content_1.FinishReason.STOP,
finishMessage: '',
safetyRatings: TEST_SAFETY_RATINGS,
},
];
const TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL = {
candidates: TEST_CANDIDATES_WITH_FUNCTION_CALL,
};
const TEST_FUNCTION_RESPONSE_PART = [
{
functionResponse: {
name: 'get_current_weather',
response: { name: 'get_current_weather', content: { weather: 'super nice' } },
},
},
];
const TEST_CANDIDATES_MISSING_ROLE = [

@@ -116,4 +149,2 @@ {

const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com';
const TEST_FILENAME = '/tmp/image.jpeg';
const INVALID_FILENAME = 'image.txt';
const TEST_GCS_FILENAME = 'gs://test_bucket/test_image.jpeg';

@@ -129,2 +160,36 @@ const TEST_MULTIPART_MESSAGE = [

];
const BASE_64_IMAGE = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==';
const INLINE_DATA_FILE_PART = {
inline_data: {
data: BASE_64_IMAGE,
mime_type: 'image/jpeg',
},
};
const TEST_MULTIPART_MESSAGE_BASE64 = [
{
role: util_1.constants.USER_ROLE,
parts: [{ text: 'What is in this picture?' }, INLINE_DATA_FILE_PART],
},
];
const TEST_TOOLS_WITH_FUNCTION_DECLARATION = [
{
function_declarations: [
{
name: 'get_current_weather',
description: 'get weather in a given location',
parameters: {
type: content_1.FunctionDeclarationSchemaType.OBJECT,
properties: {
location: { type: content_1.FunctionDeclarationSchemaType.STRING },
unit: {
type: content_1.FunctionDeclarationSchemaType.STRING,
enum: ['celsius', 'fahrenheit'],
},
},
required: ['location'],
},
},
],
},
];
const fetchResponseObj = {

@@ -147,2 +212,8 @@ status: 200,

exports.testGenerator = testGenerator;
async function* testGeneratorWithEmptyResponse() {
yield {
candidates: [],
};
}
exports.testGeneratorWithEmptyResponse = testGeneratorWithEmptyResponse;
describe('VertexAI', () => {

@@ -172,3 +243,3 @@ let vertexai;

const googleAuthOptions = {
scopes: 'test.scopes',
scopes: 'https://www.googleapis.com/auth/cloud-platform',
};

@@ -194,2 +265,28 @@ const vetexai1 = new index_1.VertexAI({

});
it('given scopes missing required scope, should throw GoogleAuthError', () => {
const invalidGoogleAuthOptionsStringScopes = { scopes: 'test.scopes' };
expect(() => {
new index_1.VertexAI({
project: PROJECT,
location: LOCATION,
googleAuthOptions: invalidGoogleAuthOptionsStringScopes,
});
}).toThrow(new errors_1.GoogleAuthError("input GoogleAuthOptions.scopes test.scopes doesn't contain required scope " +
'https://www.googleapis.com/auth/cloud-platform, ' +
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' +
'or leave GoogleAuthOptions.scopes undefined'));
const invalidGoogleAuthOptionsArrayScopes = {
scopes: ['test1.scopes', 'test2.scopes'],
};
expect(() => {
new index_1.VertexAI({
project: PROJECT,
location: LOCATION,
googleAuthOptions: invalidGoogleAuthOptionsArrayScopes,
});
}).toThrow(new errors_1.GoogleAuthError("input GoogleAuthOptions.scopes test1.scopes,test2.scopes doesn't contain required scope " +
'https://www.googleapis.com/auth/cloud-platform, ' +
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' +
'or leave GoogleAuthOptions.scopes undefined'));
});
describe('generateContent', () => {

@@ -215,4 +312,2 @@ it('returns a GenerateContentResponse', async () => {

});
});
describe('generateContent', () => {
it('returns a GenerateContentResponse when passed a GCS URI', async () => {

@@ -233,4 +328,2 @@ const req = {

});
});
describe('generateContent', () => {
it('raises an error when passed an invalid GCS URI', async () => {

@@ -242,4 +335,2 @@ const req = {

});
});
describe('generateContent', () => {
it('returns a GenerateContentResponse when passed safety_settings and generation_config', async () => {

@@ -262,4 +353,2 @@ const req = {

});
});
describe('generateContent', () => {
it('updates the base API endpoint when provided', async () => {

@@ -289,4 +378,2 @@ const vertexaiWithBasePath = new index_1.VertexAI({

});
});
describe('generateContent', () => {
it('default the base API endpoint when base API not provided', async () => {

@@ -315,4 +402,2 @@ const vertexaiWithoutBasePath = new index_1.VertexAI({

});
});
describe('generateContent', () => {
it('removes top_k when it is set to 0', async () => {

@@ -331,6 +416,2 @@ const reqWithEmptyConfigs = {

};
// const fetchResult = Promise.resolve(
// new Response(JSON.stringify(expectedStreamResult),
// fetchResponseObj));
// const requestSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedStreamResult);

@@ -343,4 +424,2 @@ await model.generateContent(reqWithEmptyConfigs);

});
});
describe('generateContent', () => {
it('includes top_k when it is within 1 - 40', async () => {

@@ -366,4 +445,2 @@ const reqWithEmptyConfigs = {

});
});
describe('generateContent', () => {
it('aggregates citation metadata', async () => {

@@ -386,2 +463,52 @@ var _a;

});
it('returns a FunctionCall when passed a FunctionDeclaration', async () => {
const req = {
contents: [
{ role: 'user', parts: [{ text: 'What is the weater like in Boston?' }] },
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedResult = {
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL,
};
const expectedStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
stream: testGenerator(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedStreamResult);
const resp = await model.generateContent(req);
expect(resp).toEqual(expectedResult);
});
it('throws ClientError when functionResponse is not immedidately following functionCall case1', async () => {
    const req = {
        contents: [
            { role: 'user', parts: [{ text: 'What is the weater like in Boston?' }] },
            {
                role: 'function',
                parts: TEST_FUNCTION_RESPONSE_PART,
            },
        ],
        tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
    };
    const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
    // The previous `await promise.catch(e => expect(...))` form passed
    // vacuously when the promise resolved; expectAsync forces a rejection
    // with the exact error message.
    await expectAsync(model.generateContent(req)).toBeRejectedWithError(expectedErrorMessage);
});
it('throws ClientError when functionResponse is not immedidately following functionCall case2', async () => {
    const req = {
        contents: [
            { role: 'user', parts: [{ text: 'What is the weater like in Boston?' }] },
            {
                role: 'function',
                parts: TEST_FUNCTION_RESPONSE_PART,
            },
        ],
        tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
    };
    const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
    // NOTE(review): this request is identical to case1 — confirm whether a
    // distinct fixture was intended for the second case.
    await expectAsync(model.generateContent(req)).toBeRejectedWithError(expectedErrorMessage);
});
});

@@ -410,4 +537,2 @@ describe('generateContentStream', () => {

});
});
describe('generateContentStream', () => {
it('returns a GenerateContentResponse when passed multi-part content with a GCS URI', async () => {

@@ -425,5 +550,62 @@ const req = {

});
it('returns a GenerateContentResponse when passed multi-part content with base64 data', async () => {
const req = {
contents: TEST_MULTIPART_MESSAGE_BASE64,
};
const expectedResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult);
const resp = await model.generateContentStream(req);
expect(resp).toEqual(expectedResult);
});
it('returns a FunctionCall when passed a FunctionDeclaration', async () => {
const req = {
contents: [
{ role: 'user', parts: [{ text: 'What is the weater like in Boston?' }] },
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
stream: testGenerator(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedStreamResult);
const resp = await model.generateContentStream(req);
expect(resp).toEqual(expectedStreamResult);
});
it('throws ClientError when functionResponse is not immedidately following functionCall case1', async () => {
    const req = {
        contents: [
            { role: 'user', parts: [{ text: 'What is the weater like in Boston?' }] },
            {
                role: 'function',
                parts: TEST_FUNCTION_RESPONSE_PART,
            },
        ],
        tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
    };
    const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
    // The previous `await promise.catch(e => expect(...))` form passed
    // vacuously when the promise resolved; expectAsync forces a rejection
    // with the exact error message.
    await expectAsync(model.generateContentStream(req)).toBeRejectedWithError(expectedErrorMessage);
});
it('throws ClientError when functionResponse is not immedidately following functionCall case2', async () => {
    const req = {
        contents: [
            { role: 'user', parts: [{ text: 'What is the weater like in Boston?' }] },
            {
                role: 'function',
                parts: TEST_FUNCTION_RESPONSE_PART,
            },
        ],
        tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
    };
    const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
    // NOTE(review): this request is identical to case1 — confirm whether a
    // distinct fixture was intended for the second case.
    await expectAsync(model.generateContentStream(req)).toBeRejectedWithError(expectedErrorMessage);
});
});
// TODO: add a streaming test with a multipart message and inline image data
// (b64 string)
describe('startChat', () => {

@@ -466,2 +648,4 @@ it('returns a ChatSession when passed a request arg', () => {

let chatSessionWithNoArgs;
let chatSessionWithEmptyResponse;
let chatSessionWithFunctionCall;
let vertexai;

@@ -479,2 +663,6 @@ let model;

chatSessionWithNoArgs = model.startChat();
chatSessionWithEmptyResponse = model.startChat();
chatSessionWithFunctionCall = model.startChat({
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
});
expectedStreamResult = {

@@ -516,8 +704,3 @@ response: Promise.resolve(TEST_MODEL_RESPONSE),

});
// TODO: unbreak this test. Currently chatSession.history is saving the
// history from the test above instead of resetting and
// expect.toThrowError() is erroring out before the expect condition is
// called
it('throws an error when the model returns an empty response', async () => {
// Reset the chat session history
const req = 'How are you doing today?';

@@ -528,2 +711,15 @@ const expectedResult = {

const expectedStreamResult = {
response: Promise.resolve(TEST_EMPTY_MODEL_RESPONSE),
stream: testGeneratorWithEmptyResponse(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedStreamResult);
await expectAsync(chatSessionWithEmptyResponse.sendMessage(req)).toBeRejected();
expect(chatSessionWithEmptyResponse.history.length).toEqual(0);
});
it('returns a GenerateContentResponse when passed multi-part content', async () => {
const req = TEST_MULTIPART_MESSAGE[0]['parts'];
const expectedResult = {
response: TEST_MODEL_RESPONSE,
};
const expectedStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),

@@ -533,11 +729,52 @@ stream: testGenerator(),

spyOn(StreamFunctions, 'processStream').and.returnValue(expectedStreamResult);
// Shouldn't append anything to history with an empty result
// expect(chatSession.history.length).toEqual(1);
// expect(await chatSession.sendMessage(req))
// .toThrowError('Did not get a response from the model');
const resp = await chatSessionWithNoArgs.sendMessage(req);
expect(resp).toEqual(expectedResult);
console.log(chatSessionWithNoArgs.history, 'hihii');
expect(chatSessionWithNoArgs.history.length).toEqual(2);
});
// TODO: add test cases for different content types passed to
// sendMessage
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => {
const functionCallChatMessage = 'What is the weather in LA?';
const expectedFunctionCallResponse = {
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL,
};
const expectedStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
stream: testGenerator(),
};
const streamSpy = spyOn(StreamFunctions, 'processStream');
streamSpy.and.returnValue(expectedStreamResult);
const response1 = await chatSessionWithFunctionCall.sendMessage(functionCallChatMessage);
expect(response1).toEqual(expectedFunctionCallResponse);
expect(chatSessionWithFunctionCall.history.length).toEqual(2);
// Send a follow-up message with a FunctionResponse
const expectedFollowUpResponse = {
response: TEST_MODEL_RESPONSE,
};
const expectedFollowUpStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
streamSpy.and.returnValue(expectedFollowUpStreamResult);
const response2 = await chatSessionWithFunctionCall.sendMessage(TEST_FUNCTION_RESPONSE_PART);
expect(response2).toEqual(expectedFollowUpResponse);
expect(chatSessionWithFunctionCall.history.length).toEqual(4);
});
it('throw ClientError when request has no content', async () => {
    const expectedErrorMessage = '[VertexAI.ClientError]: No content is provided for sending chat message.';
    // The previous `await promise.catch(e => expect(...))` form passed
    // vacuously when sendMessage resolved; expectAsync requires rejection.
    await expectAsync(chatSessionWithNoArgs.sendMessage([])).toBeRejectedWithError(expectedErrorMessage);
});
it('throw ClientError when request mix functionCall part with other types of part', async () => {
    const chatRequest = [
        'what is the weather like in LA',
        TEST_FUNCTION_RESPONSE_PART[0],
    ];
    const expectedErrorMessage = '[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.';
    await expectAsync(chatSessionWithNoArgs.sendMessage(chatRequest)).toBeRejectedWithError(expectedErrorMessage);
});
});
describe('sendMessageStram', () => {
describe('sendMessageStream', () => {
it('returns a StreamGenerateContentResponse and appends to history', async () => {

@@ -595,2 +832,39 @@ const req = 'How are you doing today?';

});
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => {
const functionCallChatMessage = 'What is the weather in LA?';
const expectedStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
stream: testGenerator(),
};
const streamSpy = spyOn(StreamFunctions, 'processStream');
streamSpy.and.returnValue(expectedStreamResult);
const response1 = await chatSessionWithFunctionCall.sendMessageStream(functionCallChatMessage);
expect(response1).toEqual(expectedStreamResult);
expect(chatSessionWithFunctionCall.history.length).toEqual(2);
// Send a follow-up message with a FunctionResponse
const expectedFollowUpStreamResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
streamSpy.and.returnValue(expectedFollowUpStreamResult);
const response2 = await chatSessionWithFunctionCall.sendMessageStream(TEST_FUNCTION_RESPONSE_PART);
expect(response2).toEqual(expectedFollowUpStreamResult);
expect(chatSessionWithFunctionCall.history.length).toEqual(4);
});
it('throw ClientError when request has no content', async () => {
    const expectedErrorMessage = '[VertexAI.ClientError]: No content is provided for sending chat message.';
    // The previous `await promise.catch(e => expect(...))` form passed
    // vacuously when sendMessageStream resolved; expectAsync requires
    // rejection with the exact error message.
    await expectAsync(chatSessionWithNoArgs.sendMessageStream([])).toBeRejectedWithError(expectedErrorMessage);
});
it('throw ClientError when request mix functionCall part with other types of part', async () => {
    const chatRequest = [
        'what is the weather like in LA',
        TEST_FUNCTION_RESPONSE_PART[0],
    ];
    const expectedErrorMessage = '[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.';
    await expectAsync(chatSessionWithNoArgs.sendMessageStream(chatRequest)).toBeRejectedWithError(expectedErrorMessage);
});
});

@@ -597,0 +871,0 @@ });

# Changelog
## [0.3.0](https://github.com/googleapis/nodejs-vertexai/compare/v0.2.1...v0.3.0) (2024-01-30)
### Features
* add function calling support ([1deb4e9](https://github.com/googleapis/nodejs-vertexai/commit/1deb4e920205d2fff6da780175de6045bd853885))
### Bug Fixes
* throw error when GoogleAuthOptions.scopes doesn't include required scope. ([558aee9](https://github.com/googleapis/nodejs-vertexai/commit/558aee98d76192b4a63b3d28abba3f3d4cda1762))
* throws instructive client side error message when bad request happens for function calling ([c90203d](https://github.com/googleapis/nodejs-vertexai/commit/c90203d153407daa08763c273a827a5e9db54a70))
## [0.2.1](https://github.com/googleapis/nodejs-vertexai/compare/v0.2.0...v0.2.1) (2024-01-05)

@@ -4,0 +17,0 @@

{
"name": "@google-cloud/vertexai",
"description": "Vertex Generative AI client for Node.js",
"version": "0.2.1",
"version": "0.3.0",
"license": "Apache-2.0",

@@ -6,0 +6,0 @@ "author": "Google LLC",

@@ -170,2 +170,102 @@ # Vertex AI Node.js SDK

## Function calling
The Node SDK supports
[function calling](https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/function-calling) via `sendMessage`, `sendMessageStream`, `generateContent`, and `generateContentStream`. We recommend using it through chat methods
(`sendMessage` or `sendMessageStream`) but have included examples of both
approaches below.
### Function declarations and response
This is an example of a function declaration and function response, which are
passed to the model in the snippets that follow.
```typescript
const functionDeclarations = [
{
function_declarations: [
{
name: "get_current_weather",
description: 'get weather in a given location',
parameters: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
location: {type: FunctionDeclarationSchemaType.STRING},
unit: {
type: FunctionDeclarationSchemaType.STRING,
enum: ['celsius', 'fahrenheit'],
},
},
required: ['location'],
},
},
],
},
];
const functionResponseParts = [
{
functionResponse: {
name: "get_current_weather",
response:
{name: "get_current_weather", content: {weather: "super nice"}},
},
},
];
```
### Function calling with chat
```typescript
async function functionCallingChat() {
// Create a chat session and pass your function declarations
const chat = generativeModel.startChat({
tools: functionDeclarations,
});
const chatInput1 = 'What is the weather in Boston?';
// This should include a functionCall response from the model
const result1 = await chat.sendMessageStream(chatInput1);
for await (const item of result1.stream) {
console.log(item.candidates[0]);
}
const response1 = await result1.response;
// Send a follow up message with a FunctionResponse
const result2 = await chat.sendMessageStream(functionResponseParts);
for await (const item of result2.stream) {
console.log(item.candidates[0]);
}
// This should include a text response from the model using the response content
// provided above
const response2 = await result2.response;
}
functionCallingChat();
```
### Function calling with generateContentStream
```typescript
async function functionCallingGenerateContent() {
const request = {
contents: [
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
{role: 'model', parts: [{functionCall: {name: 'get_current_weather', args: {'location': 'Boston'}}}]},
{role: 'function', parts: functionResponseParts}
],
tools: functionDeclarations,
};
const streamingResp =
await generativeModel.generateContentStream(request);
for await (const item of streamingResp.stream) {
console.log(item.candidates[0]);
}
}
functionCallingGenerateContent();
```
## License

@@ -172,0 +272,0 @@

@@ -37,2 +37,3 @@ /**

StreamGenerateContentResult,
Tool,
VertexInit,

@@ -97,21 +98,3 @@ } from './types/content';

) {
let opts: GoogleAuthOptions;
if (!googleAuthOptions) {
opts = {
scopes: 'https://www.googleapis.com/auth/cloud-platform',
};
} else {
if (
googleAuthOptions.projectId &&
googleAuthOptions.projectId !== project
) {
throw new Error(
`inconsistent project ID values. argument project got value ${project} but googleAuthOptions.projectId got value ${googleAuthOptions.projectId}`
);
}
opts = googleAuthOptions;
if (!opts.scopes) {
opts.scopes = 'https://www.googleapis.com/auth/cloud-platform';
}
}
const opts = this.validateGoogleAuthOptions(project, googleAuthOptions);
this.project = project;

@@ -157,5 +140,42 @@ this.location = location;

modelParams.generation_config,
modelParams.safety_settings
modelParams.safety_settings,
modelParams.tools
);
}
validateGoogleAuthOptions(
  project: string,
  googleAuthOptions?: GoogleAuthOptions
): GoogleAuthOptions {
  const requiredScope = 'https://www.googleapis.com/auth/cloud-platform';
  // Nothing supplied: build the minimal options carrying the required scope.
  if (!googleAuthOptions) {
    return {scopes: requiredScope};
  }
  // A projectId embedded in the auth options must agree with the explicit
  // project argument.
  if (
    googleAuthOptions.projectId &&
    googleAuthOptions.projectId !== project
  ) {
    throw new Error(
      `inconsistent project ID values. argument project got value ${project} but googleAuthOptions.projectId got value ${googleAuthOptions.projectId}`
    );
  }
  const opts = googleAuthOptions;
  // Caller left scopes unset: default to the required scope (this mutates
  // the caller-supplied options object, matching existing behavior).
  if (!opts.scopes) {
    opts.scopes = requiredScope;
    return opts;
  }
  // Scopes were provided explicitly: they must contain the required scope,
  // whether given as a single string or as an array.
  const scopes = opts.scopes;
  const hasRequiredScope =
    typeof scopes === 'string'
      ? scopes === requiredScope
      : !Array.isArray(scopes) || scopes.includes(requiredScope);
  if (!hasRequiredScope) {
    throw new GoogleAuthError(
      `input GoogleAuthOptions.scopes ${opts.scopes} doesn't contain required scope ${requiredScope}, please include ${requiredScope} into GoogleAuthOptions.scopes or leave GoogleAuthOptions.scopes undefined`
    );
  }
  return opts;
}
}

@@ -173,2 +193,3 @@

generation_config?: GenerationConfig;
tools?: Tool[];
}

@@ -205,2 +226,3 @@

safety_settings?: SafetySetting[];
tools?: Tool[];

@@ -221,2 +243,5 @@ get history(): Content[] {

this._vertex_instance = request._vertex_instance;
this.generation_config = request.generation_config;
this.safety_settings = request.safety_settings;
this.tools = request.tools;
}

@@ -232,7 +257,9 @@

): Promise<GenerateContentResult> {
const newContent: Content = formulateNewContent(request);
const newContent: Content[] =
formulateNewContentFromSendMessageRequest(request);
const generateContentrequest: GenerateContentRequest = {
contents: this.historyInternal.concat([newContent]),
contents: this.historyInternal.concat(newContent),
safety_settings: this.safety_settings,
generation_config: this.generation_config,
tools: this.tools,
};

@@ -249,3 +276,3 @@

if (generateContentResponse.candidates.length !== 0) {
this.historyInternal.push(newContent);
this.historyInternal = this.historyInternal.concat(newContent);
const contentFromAssistant =

@@ -267,3 +294,3 @@ generateContentResponse.candidates[0].content;

streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>,
newContent: Content
newContent: Content[]
): Promise<void> {

@@ -276,3 +303,3 @@ const streamGenerateContentResult =

if (streamGenerateContentResponse.candidates.length !== 0) {
this.historyInternal.push(newContent);
this.historyInternal = this.historyInternal.concat(newContent);
const contentFromAssistant =

@@ -298,7 +325,9 @@ streamGenerateContentResponse.candidates[0].content;

): Promise<StreamGenerateContentResult> {
const newContent: Content = formulateNewContent(request);
const newContent: Content[] =
formulateNewContentFromSendMessageRequest(request);
const generateContentrequest: GenerateContentRequest = {
contents: this.historyInternal.concat([newContent]),
contents: this.historyInternal.concat(newContent),
safety_settings: this.safety_settings,
generation_config: this.generation_config,
tools: this.tools,
};

@@ -331,2 +360,3 @@

safety_settings?: SafetySetting[];
tools?: Tool[];
private _vertex_instance: VertexAI_Preview;

@@ -348,3 +378,4 @@ private _use_non_stream = false;

generation_config?: GenerationConfig,
safety_settings?: SafetySetting[]
safety_settings?: SafetySetting[],
tools?: Tool[]
) {

@@ -355,2 +386,3 @@ this._vertex_instance = vertex_instance;

this.safety_settings = safety_settings;
this.tools = tools;
if (model.startsWith('models/')) {

@@ -377,3 +409,3 @@ this.publisherModelEndpoint = `publishers/google/${this.model}`;

validateGcsInput(request.contents);
validateGenerateContentRequest(request);

@@ -401,2 +433,3 @@ if (request.generation_config) {

safety_settings: request.safety_settings ?? this.safety_settings,
tools: request.tools ?? [],
};

@@ -433,3 +466,3 @@

);
validateGcsInput(request.contents);
validateGenerateContentRequest(request);

@@ -446,2 +479,3 @@ if (request.generation_config) {

safety_settings: request.safety_settings ?? this.safety_settings,
tools: request.tools ?? [],
};

@@ -504,2 +538,3 @@ const response = await postRequest({

request.safety_settings ?? this.safety_settings;
startChatRequest.tools = request.tools ?? this.tools;
}

@@ -510,3 +545,5 @@ return new ChatSession(startChatRequest);

function formulateNewContent(request: string | Array<string | Part>): Content {
function formulateNewContentFromSendMessageRequest(
request: string | Array<string | Part>
): Content[] {
let newParts: Part[] = [];

@@ -526,6 +563,47 @@

const newContent: Content = {role: constants.USER_ROLE, parts: newParts};
return newContent;
return assignRoleToPartsAndValidateSendMessageRequest(newParts);
}
/**
 * Splits a sendMessage part list into role-tagged Content entries.
 * FunctionResponse parts must be sent under the function role, while every
 * other part type belongs to the user role; the two kinds may not appear in
 * the same message.
 * @ignore
 * @param {Array<Part>} parts Array of parts to pass to the model
 * @return {Content[]} Array of content items
 */
function assignRoleToPartsAndValidateSendMessageRequest(
  parts: Array<Part>
): Content[] {
  const functionResponseParts: Part[] = [];
  const userParts: Part[] = [];
  // Partition the parts by kind, preserving their original order.
  for (const part of parts) {
    if ('functionResponse' in part) {
      functionResponseParts.push(part);
    } else {
      userParts.push(part);
    }
  }
  if (functionResponseParts.length > 0 && userParts.length > 0) {
    throw new ClientError(
      'Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.'
    );
  }
  if (functionResponseParts.length === 0 && userParts.length === 0) {
    throw new ClientError('No content is provided for sending chat message.');
  }
  if (userParts.length > 0) {
    return [{role: constants.USER_ROLE, parts: userParts}];
  }
  return [{role: constants.FUNCTION_ROLE, parts: functionResponseParts}];
}
function throwErrorIfNotOK(response: Response | undefined) {

@@ -562,2 +640,23 @@ if (response === undefined) {

/**
 * Validates that a functionResponse turn, when present as the latest turn,
 * directly follows a functionCall turn.
 * @ignore
 * @param {Content[]} contents conversation turns to validate
 * @throws {ClientError} when a function response turn does not immediately
 *     follow a function call turn
 */
function validateFunctionResponseRequest(contents: Content[]) {
  // Guard against empty history or an empty parts array: the original code
  // evaluated `'functionResponse' in undefined`, raising an unhelpful
  // TypeError instead of validating.
  if (contents.length === 0) {
    return;
  }
  const latestContentPart = contents[contents.length - 1].parts[0];
  if (!latestContentPart || !('functionResponse' in latestContentPart)) {
    return;
  }
  const errorMessage =
    'Please ensure that function response turn comes immediately after a function call turn.';
  if (contents.length < 2) {
    throw new ClientError(errorMessage);
  }
  const secondLatestContentPart = contents[contents.length - 2].parts[0];
  if (!secondLatestContentPart || !('functionCall' in secondLatestContentPart)) {
    throw new ClientError(errorMessage);
  }
}
/**
 * Runs all client-side checks on a GenerateContentRequest before dispatch:
 * GCS URI validation and function-response turn ordering.
 * @ignore
 */
function validateGenerateContentRequest(request: GenerateContentRequest) {
  const requestContents = request.contents;
  validateGcsInput(requestContents);
  validateFunctionResponseRequest(requestContents);
}
function validateGenerationConfig(

@@ -564,0 +663,0 @@ generation_config: GenerationConfig

@@ -196,2 +196,9 @@ /**

}
if (part.functionCall) {
aggregatedResponse.candidates[i].content.parts[0].functionCall =
part.functionCall;
// the empty 'text' key should be removed if functionCall is in the
// response
delete aggregatedResponse.candidates[i].content.parts[0].text;
}
}

@@ -198,0 +205,0 @@ }

@@ -81,2 +81,3 @@ /**

generation_config?: GenerationConfig;
tools?: Tool[];
}

@@ -209,2 +210,4 @@

* 3. file_data
* 4. functionResponse
* 5. functionCall
*/

@@ -220,2 +223,6 @@ // TODO: Adjust so one must be true.

* @property {never} - [file_data]. file_data is not expected for TextPart.
* @property {never} - [functionResponse]. functionResponse is not expected for
* TextPart.
* @property {never} - [functionCall]. functionCall is not expected for
* TextPart.
*

@@ -227,2 +234,4 @@ */

file_data?: never;
functionResponse?: never;
functionCall?: never;
}

@@ -233,4 +242,11 @@

* @property {never} - [text]. text is not expected for InlineDataPart.
* @property {GenerativeContentBlob} - inline_data. Only this property is expected for InlineDataPart. {@link GenerativeContentBlob}
* @property {never} - [file_data]. file_data is not expected for InlineDataPart.
* @property {GenerativeContentBlob} - inline_data. Only this property is
* expected for InlineDataPart. {@link GenerativeContentBlob}
* @property {never} - [file_data]. file_data is not expected for
* InlineDataPart.
* @property {never} - [functionResponse]. functionResponse is not expected for
* InlineDataPart.
* @property {never} - [functionCall]. functionCall is not expected for
* InlineDataPart.
*
*/

@@ -241,2 +257,4 @@ export interface InlineDataPart extends BasePart {

file_data?: never;
functionResponse?: never;
functionCall?: never;
}

@@ -257,4 +275,11 @@

* @property {never} - [text]. text is not expected for FileDataPart.
* @property {never} - [inline_data]. inline_data is not expected for FileDataPart.
* @property {FileData} - file_data. Only this property is expected for FileDataPart. {@link FileData}
* @property {never} - [inline_data]. inline_data is not expected for
* FileDataPart.
* @property {FileData} - file_data. Only this property is expected for
* FileDataPart. {@link FileData}
* @property {never} - [functionResponse]. functionResponse is not expected for
* FileDataPart.
* @property {never} - [functionCall]. functionCall is not expected for
* FileDataPart.
*
*/

@@ -265,13 +290,64 @@ export interface FileDataPart extends BasePart {

file_data: FileData;
functionResponse?: never;
functionCall?: never;
}
/**
 * A function response part of a conversation with the model.
 * @property {never} - [text]. text is not expected for FunctionResponsePart.
 * @property {never} - [inline_data]. inline_data is not expected for
 * FunctionResponsePart.
 * @property {never} - [file_data]. file_data is not expected for
 * FunctionResponsePart.
 * @property {FunctionResponse} - functionResponse. Only functionResponse is
 * expected for FunctionResponsePart. {@link FunctionResponse}
 * @property {never} - [functionCall]. functionCall is not expected for
 * FunctionResponsePart.
 */
export interface FunctionResponsePart extends BasePart {
text?: never;
inline_data?: never;
file_data?: never;
functionResponse: FunctionResponse;
functionCall?: never;
}
/**
 * A function call part of a conversation with the model.
 * @property {never} - [text]. text is not expected for FunctionCallPart.
 * @property {never} - [inline_data]. inline_data is not expected for
 * FunctionCallPart.
 * @property {never} - [file_data]. file_data is not expected for
 * FunctionCallPart.
 * @property {never} - [functionResponse]. functionResponse is not expected for
 * FunctionCallPart.
 * @property {FunctionCall} - functionCall. Only functionCall is expected for
 * FunctionCallPart. {@link FunctionCall}
 */
export interface FunctionCallPart extends BasePart {
text?: never;
inline_data?: never;
file_data?: never;
functionResponse?: never;
functionCall: FunctionCall;
}
/**
 * A datatype containing media that is part of a multi-part {@link Content}
 * message. A `Part` is a union type of {@link TextPart}, {@link
 * InlineDataPart}, {@link FileDataPart}, {@link FunctionResponsePart}, and
 * {@link FunctionCallPart}. A `Part` has one of the following mutually
 * exclusive fields:
 * 1. text
 * 2. inline_data
 * 3. file_data
 * 4. functionResponse
 * 5. functionCall
 */
export declare type Part =
| TextPart
| InlineDataPart
| FileDataPart
| FunctionResponsePart
| FunctionCallPart;

@@ -410,2 +486,3 @@ /**

citationMetadata?: CitationMetadata;
functionCall?: FunctionCall;
}

@@ -434,1 +511,121 @@

}
/**
 * A predicted FunctionCall returned from the model that contains a string
 * representing the FunctionDeclaration.name with the parameters and their
 * values.
 * @property {string} - name The name of the function specified in
 * FunctionDeclaration.name.
 * @property {object} - args The arguments to pass to the function.
 */
export declare interface FunctionCall {
name: string;
args: object;
}
/**
 * The result output of a FunctionCall that contains a string representing
 * the FunctionDeclaration.name and a structured JSON object containing any
 * output from the function call. It is used as context to the model.
 * @property {string} - name The name of the function specified in
 * FunctionDeclaration.name.
 * @property {object} - response The structured JSON output of the function
 * call, passed back to the model as context.
 */
export declare interface FunctionResponse {
name: string;
response: object;
}
/**
 * Structured representation of a function declaration as defined by the
 * [OpenAPI 3.0 specification](https://spec.openapis.org/oas/v3.0.3). Included
 * in this declaration are the function name and parameters. This
 * FunctionDeclaration is a representation of a block of code that can be used
 * as a Tool by the model and executed by the client.
 * @property {string} - name The name of the function to call. Must start with a
 * letter or an underscore. Must be a-z, A-Z, 0-9, or contain underscores and
 * dashes, with a max length of 64.
 * @property {string} - description Description and purpose of the function.
 * Model uses it to decide how and whether to call the function.
 * @property {FunctionDeclarationSchema} - parameters Describes the parameters
 * to this function in JSON Schema Object format. Reflects the Open API 3.03
 * Parameter Object. string Key: the name of the parameter. Parameter names are
 * case sensitive. Schema Value: the Schema defining the type used for the
 * parameter. For function with no parameters, this can be left unset. Example
 * with 1 required and 1 optional parameter:
 *   type: OBJECT
 *   properties:
 *     param1:
 *       type: STRING
 *     param2:
 *       type: INTEGER
 *   required:
 *     - param1
 */
export declare interface FunctionDeclaration {
name: string;
description?: string;
parameters?: FunctionDeclarationSchema;
}
/**
 * A Tool is a piece of code that enables the system to interact with
 * external systems to perform an action, or set of actions, outside of
 * knowledge and scope of the model.
 * @property {FunctionDeclaration[]} - function_declarations One or more
 * function declarations to be passed to the model along with the current user
 * query. Model may decide to call a subset of these functions by populating
 * [FunctionCall][content.part.function_call] in the response. User should
 * provide a [FunctionResponse][content.part.function_response] for each
 * function call in the next turn. Based on the function responses, Model will
 * generate the final response back to the user. Maximum 64 function
 * declarations can be provided.
 */
export declare interface Tool {
function_declarations: FunctionDeclaration[];
}
/**
 * Contains the list of OpenAPI data types
 * as defined by https://swagger.io/docs/specification/data-models/data-types/
 * Used as the `type` field of {@link FunctionDeclarationSchema} and
 * {@link FunctionDeclarationSchemaProperty}.
 * @public
 */
export enum FunctionDeclarationSchemaType {
STRING = 'STRING',
NUMBER = 'NUMBER',
INTEGER = 'INTEGER',
BOOLEAN = 'BOOLEAN',
ARRAY = 'ARRAY',
OBJECT = 'OBJECT',
}
/**
 * Schema for parameters passed to [FunctionDeclaration.parameters]
 * @public
 * @property {FunctionDeclarationSchemaType} - type The top-level OpenAPI type
 * of the parameters object.
 * @property {object} - properties Map of parameter name to the schema of that
 * parameter.
 * @property {string} - [description] Optional description of the parameters.
 * @property {string[]} - [required] Names of parameters that must be supplied.
 */
export interface FunctionDeclarationSchema {
type: FunctionDeclarationSchemaType;
properties: {[k: string]: FunctionDeclarationSchemaProperty};
description?: string;
required?: string[];
}
/**
 * Schema is used to define the format of input/output data.
 * Represents a select subset of an OpenAPI 3.0 schema object.
 * More fields may be added in the future as needed.
 * @public
 */
export interface FunctionDeclarationSchemaProperty {
// OpenAPI data type of this property.
type?: FunctionDeclarationSchemaType;
// OpenAPI format hint (e.g. 'date-time' for STRING) per the OpenAPI spec.
format?: string;
description?: string;
nullable?: boolean;
// Per OpenAPI, schema of array items — relevant when type is ARRAY.
items?: FunctionDeclarationSchema;
// Allowed literal values when this property is an enumeration.
enum?: string[];
// Per OpenAPI, nested properties — relevant when type is OBJECT.
properties?: {[k: string]: FunctionDeclarationSchema};
required?: string[];
example?: unknown;
}

@@ -21,5 +21,6 @@ /**

// Role labels used in multi-turn Content messages.
export const MODEL_ROLE = 'model';
export const FUNCTION_ROLE = 'function';
// Components of the client-library user agent string.
const USER_AGENT_PRODUCT = 'model-builder';
// Fix: the diff residue left two declarations of CLIENT_LIBRARY_VERSION
// (0.2.1 and 0.3.0) — keep only the released 0.3.0 value.
const CLIENT_LIBRARY_VERSION = '0.3.0'; // x-release-please-version
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`;
export const USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`;

@@ -19,7 +19,5 @@ /**

// @ts-ignore
import * as assert from 'assert';
import {ClientError, TextPart, VertexAI} from '../src';
import {FunctionDeclarationSchemaType} from '../src/types';
import {ClientError, VertexAI, TextPart} from '../src';
// TODO: this env var isn't getting populated correctly
const PROJECT = process.env.GCLOUD_PROJECT;

@@ -37,3 +35,3 @@ const LOCATION = 'us-central1';

file_data: {
file_uri: 'gs://nodejs_vertex_system_test_resources/scones.jpg',
file_uri: 'gs://generativeai-downloads/images/scones.jpg',
mime_type: 'image/jpeg',

@@ -58,4 +56,48 @@ },

// Name shared by the declared tool, the model's functionCall, and the
// client's functionResponse in the function-calling system tests.
const FUNCTION_CALL_NAME = 'get_current_weather';
// Single tool declaring get_current_weather(location, unit?) for requests.
const TOOLS_WITH_FUNCTION_DECLARATION = [
{
function_declarations: [
{
name: FUNCTION_CALL_NAME,
description: 'get weather in a given location',
parameters: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
location: {type: FunctionDeclarationSchemaType.STRING},
unit: {
type: FunctionDeclarationSchemaType.STRING,
enum: ['celsius', 'fahrenheit'],
},
},
required: ['location'],
},
},
],
},
];
// Fabricated weather value; tests assert the model echoes it back.
const WEATHER_FORECAST = 'super nice';
// Client-side function result sent back to the model as a 'function' turn.
const FUNCTION_RESPONSE_PART = [
{
functionResponse: {
name: FUNCTION_CALL_NAME,
response: {
name: FUNCTION_CALL_NAME,
content: {weather: WEATHER_FORECAST},
},
},
},
];
// Canned model functionCall turn used to seed conversation history.
const FUNCTION_CALL = [
{functionCall: {name: FUNCTION_CALL_NAME, args: {location: 'boston'}}},
];
// Initialize Vertex with your Cloud project and location.
// Fix: the diff residue left two `vertex_ai` declarations — keep the
// env-driven one rather than the hard-coded project id.
// NOTE(review): `PROJECT as string` hides the case where GCLOUD_PROJECT is
// unset (see the TODO on the PROJECT declaration) — confirm it is populated.
const vertex_ai = new VertexAI({
project: PROJECT as string,
location: LOCATION,
});

@@ -85,7 +127,5 @@ const generativeTextModel = vertex_ai.preview.getGenerativeModel({

// TODO (b/316599049): update tests to use jasmine expect syntax:
// expect(...).toBeInstanceOf(...)
describe('generateContentStream', () => {
beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000;
});

@@ -98,4 +138,3 @@

for await (const item of streamingResp.stream) {
assert(
item.candidates[0],
expect(item.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream, for item ${item}`

@@ -106,4 +145,3 @@ );

const aggregatedResp = await streamingResp.response;
assert(
aggregatedResp.candidates[0],
expect(aggregatedResp.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`

@@ -118,4 +156,3 @@ );

for await (const item of streamingResp.stream) {
assert(
item.candidates[0],
expect(item.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream, for item ${item}`

@@ -125,4 +162,4 @@ );

for (const part of candidate.content.parts as TextPart[]) {
assert(
!part.text.includes('\ufffd'),
expect(part.text).not.toContain(
'\ufffd',
`sys test failure on generateContentStream, for item ${item}`

@@ -135,4 +172,3 @@ );

const aggregatedResp = await streamingResp.response;
assert(
aggregatedResp.candidates[0],
expect(aggregatedResp.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`

@@ -147,4 +183,3 @@ );

for await (const item of streamingResp.stream) {
assert(
item.candidates[0],
expect(item.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream, for item ${item}`

@@ -155,4 +190,3 @@ );

const aggregatedResp = await streamingResp.response;
assert(
aggregatedResp.candidates[0],
expect(aggregatedResp.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`

@@ -174,33 +208,79 @@ );

await generativeVisionModel.generateContentStream(badRequest).catch(e => {
assert(
e instanceof ClientError,
`sys test failure on generateContentStream when having bad request should throw ClientError but actually thrown ${e}`
expect(e).toBeInstanceOf(ClientError);
expect(e.message).toBe(
'[VertexAI.ClientError]: got status: 400 Bad Request',
`sys test failure on generateContentStream when having bad request
got wrong error message: ${e.message}`
);
assert(
e.message === '[VertexAI.ClientError]: got status: 400 Bad Request',
`sys test failure on generateContentStream when having bad request got wrong error message: ${e.message}`
);
});
});
// Streams a multipart (text + GCS image) request and checks every chunk and
// the aggregated response contain a candidate. (Previously disabled due to a
// 500 on the system test project; re-enabled here.)
it('should should return a stream and aggregated response when passed multipart GCS content', async () => {
const streamingResp = await generativeVisionModel.generateContentStream(
MULTI_PART_GCS_REQUEST
);
for await (const item of streamingResp.stream) {
expect(item.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream, for item ${item}`
);
}
const aggregatedResp = await streamingResp.response;
expect(aggregatedResp.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`
);
});
// Function-calling round trip over the streaming API: user question, seeded
// model functionCall, client functionResponse. The final model text is
// expected to incorporate the fabricated WEATHER_FORECAST value.
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => {
const request = {
contents: [
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]},
{role: 'model', parts: FUNCTION_CALL},
{role: 'function', parts: FUNCTION_RESPONSE_PART},
],
tools: TOOLS_WITH_FUNCTION_DECLARATION,
};
const streamingResp =
await generativeTextModel.generateContentStream(request);
for await (const item of streamingResp.stream) {
expect(item.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream, for item ${item}`
);
expect(item.candidates[0].content.parts[0].text?.toLowerCase()).toContain(
WEATHER_FORECAST
);
}
});
});
// TODO (b/316599049): add tests for generateContent and sendMessage
// Smoke test for the non-streaming generateContent method.
describe('generateContent', () => {
beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
});
it('should return the aggregated response', async () => {
const response = await generativeTextModel.generateContent(TEXT_REQUEST);
const aggregatedResp = await response.response;
expect(aggregatedResp.candidates[0]).toBeTruthy(
// NOTE(review): the failure message says generateContentStream but this
// test exercises generateContent.
`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`
);
});
});
describe('sendMessage', () => {
beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000;
});
// Verifies that one chat turn returns a candidate and records both the user
// message and the model reply (2 entries) in session history.
it('should populate history and return a chat response', async () => {
const chat = generativeTextModel.startChat();
const chatInput1 = 'How can I learn more about Node.js?';
const result1 = await chat.sendMessage(chatInput1);
const response1 = await result1.response;
expect(response1.candidates[0]).toBeTruthy(
`sys test failure on sendMessage for aggregated response: ${response1}`
);
expect(chat.history.length).toBe(2);
});
});
describe('sendMessageStream', () => {

@@ -219,4 +299,3 @@ beforeEach(() => {

for await (const item of result1.stream) {
assert(
item.candidates[0],
expect(item.candidates[0]).toBeTruthy(
`sys test failure on sendMessageStream, for item ${item}`

@@ -226,4 +305,3 @@ );

const resp = await result1.response;
assert(
resp.candidates[0],
expect(resp.candidates[0]).toBeTruthy(
`sys test failure on sendMessageStream for aggregated response: ${resp}`

@@ -238,4 +316,3 @@ );

for await (const item of result1.stream) {
assert(
item.candidates[0],
expect(item.candidates[0]).toBeTruthy(
`sys test failure on sendMessageStream, for item ${item}`

@@ -245,4 +322,3 @@ );

const resp = await result1.response;
assert(
resp.candidates[0],
expect(resp.candidates[0]).toBeTruthy(
`sys test failure on sendMessageStream for aggregated response: ${resp}`

@@ -272,2 +348,33 @@ );

});
// Function-calling over chat streaming: first turn should yield a
// functionCall naming the declared tool with a 'location' arg; the follow-up
// functionResponse turn should yield text echoing WEATHER_FORECAST.
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => {
const chat = generativeTextModel.startChat({
tools: TOOLS_WITH_FUNCTION_DECLARATION,
});
const chatInput1 = 'What is the weather in Boston?';
const result1 = await chat.sendMessageStream(chatInput1);
for await (const item of result1.stream) {
expect(item.candidates[0]).toBeTruthy(
`sys test failure on sendMessageStream with function calling, for item ${item}`
);
}
const response1 = await result1.response;
expect(
JSON.stringify(response1.candidates[0].content.parts[0].functionCall)
).toContain(FUNCTION_CALL_NAME);
expect(
JSON.stringify(response1.candidates[0].content.parts[0].functionCall)
).toContain('location');
// Send a follow up message with a FunctionResponse
const result2 = await chat.sendMessageStream(FUNCTION_RESPONSE_PART);
for await (const item of result2.stream) {
expect(item.candidates[0]).toBeTruthy(
`sys test failure on sendMessageStream with function calling, for item ${item}`
);
}
const response2 = await result2.response;
expect(
JSON.stringify(response2.candidates[0].content.parts[0].text)
).toContain(WEATHER_FORECAST);
});
});

@@ -278,4 +385,3 @@

const countTokensResp = await generativeTextModel.countTokens(TEXT_REQUEST);
assert(
countTokensResp.totalTokens,
expect(countTokensResp.totalTokens).toBeTruthy(
`sys test failure on countTokens, ${countTokensResp}`

@@ -288,3 +394,3 @@ );

beforeEach(() => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 25000;
jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000;
});

@@ -297,4 +403,3 @@

for await (const item of streamingResp.stream) {
assert(
item.candidates[0],
expect(item.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream using models/gemini-pro, for item ${item}`

@@ -305,4 +410,3 @@ );

const aggregatedResp = await streamingResp.response;
assert(
aggregatedResp.candidates[0],
expect(aggregatedResp.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream using models/gemini-pro for aggregated response: ${aggregatedResp}`

@@ -319,4 +423,3 @@ );

for await (const item of streamingResp.stream) {
assert(
item.candidates[0],
expect(item.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream using models/gemini-pro-vision, for item ${item}`

@@ -327,4 +430,3 @@ );

const aggregatedResp = await streamingResp.response;
assert(
aggregatedResp.candidates[0],
expect(aggregatedResp.candidates[0]).toBeTruthy(
`sys test failure on generateContentStream using models/gemini-pro-vision for aggregated response: ${aggregatedResp}`

@@ -331,0 +433,0 @@ );

@@ -30,2 +30,3 @@ /**

FinishReason,
FunctionDeclarationSchemaType,
GenerateContentRequest,

@@ -40,3 +41,5 @@ GenerateContentResponse,

StreamGenerateContentResult,
Tool,
} from '../src/types/content';
import {GoogleAuthError} from '../src/types/errors';
import {constants} from '../src/util';

@@ -119,2 +122,37 @@

};
// Canned model part containing a predicted functionCall.
const TEST_FUNCTION_CALL_RESPONSE = {
functionCall: {
name: 'get_current_weather',
args: {
location: 'LA',
unit: 'fahrenheit',
},
},
};
// Candidate list wrapping the functionCall part in a model-role turn.
const TEST_CANDIDATES_WITH_FUNCTION_CALL = [
{
index: 1,
content: {
role: constants.MODEL_ROLE,
parts: [TEST_FUNCTION_CALL_RESPONSE],
},
finishReason: FinishReason.STOP,
finishMessage: '',
safetyRatings: TEST_SAFETY_RATINGS,
},
];
// Full mocked response used where the model is expected to call a function.
const TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL = {
candidates: TEST_CANDIDATES_WITH_FUNCTION_CALL,
};
// Client-side functionResponse part sent back to the model in tests.
const TEST_FUNCTION_RESPONSE_PART = [
{
functionResponse: {
name: 'get_current_weather',
response: {name: 'get_current_weather', content: {weather: 'super nice'}},
},
},
];
const TEST_CANDIDATES_MISSING_ROLE = [

@@ -138,4 +176,2 @@ {

const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com';
const TEST_FILENAME = '/tmp/image.jpeg';
const INVALID_FILENAME = 'image.txt';
const TEST_GCS_FILENAME = 'gs://test_bucket/test_image.jpeg';

@@ -153,2 +189,40 @@

// 1x1 PNG encoded as base64, used for inline-data (non-GCS) image requests.
const BASE_64_IMAGE =
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==';
const INLINE_DATA_FILE_PART = {
inline_data: {
data: BASE_64_IMAGE,
mime_type: 'image/jpeg',
},
};
// Multipart user turn combining text with the inline base64 image.
const TEST_MULTIPART_MESSAGE_BASE64 = [
{
role: constants.USER_ROLE,
parts: [{text: 'What is in this picture?'}, INLINE_DATA_FILE_PART],
},
];
// Tool declaring get_current_weather(location, unit?) for unit tests.
const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [
{
function_declarations: [
{
name: 'get_current_weather',
description: 'get weather in a given location',
parameters: {
type: FunctionDeclarationSchemaType.OBJECT,
properties: {
location: {type: FunctionDeclarationSchemaType.STRING},
unit: {
type: FunctionDeclarationSchemaType.STRING,
enum: ['celsius', 'fahrenheit'],
},
},
required: ['location'],
},
},
],
},
];
const fetchResponseObj = {

@@ -172,2 +246,8 @@ status: 200,

/** Async generator yielding exactly one response chunk with no candidates. */
export async function* testGeneratorWithEmptyResponse(): AsyncGenerator<GenerateContentResponse> {
yield {candidates: []};
}
describe('VertexAI', () => {

@@ -203,3 +283,3 @@ let vertexai: VertexAI;

const googleAuthOptions = {
scopes: 'test.scopes',
scopes: 'https://www.googleapis.com/auth/cloud-platform',
};

@@ -231,2 +311,37 @@ const vetexai1 = new VertexAI({

// Constructor must reject both string and string[] scopes that omit the
// required cloud-platform scope, with an error naming the offending scopes.
it('given scopes missing required scope, should throw GoogleAuthError', () => {
const invalidGoogleAuthOptionsStringScopes = {scopes: 'test.scopes'};
expect(() => {
new VertexAI({
project: PROJECT,
location: LOCATION,
googleAuthOptions: invalidGoogleAuthOptionsStringScopes,
});
}).toThrow(
new GoogleAuthError(
"input GoogleAuthOptions.scopes test.scopes doesn't contain required scope " +
'https://www.googleapis.com/auth/cloud-platform, ' +
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' +
'or leave GoogleAuthOptions.scopes undefined'
)
);
const invalidGoogleAuthOptionsArrayScopes = {
scopes: ['test1.scopes', 'test2.scopes'],
};
expect(() => {
new VertexAI({
project: PROJECT,
location: LOCATION,
googleAuthOptions: invalidGoogleAuthOptionsArrayScopes,
});
}).toThrow(
new GoogleAuthError(
"input GoogleAuthOptions.scopes test1.scopes,test2.scopes doesn't contain required scope " +
'https://www.googleapis.com/auth/cloud-platform, ' +
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' +
'or leave GoogleAuthOptions.scopes undefined'
)
);
});
describe('generateContent', () => {

@@ -256,5 +371,3 @@ it('returns a GenerateContentResponse', async () => {

});
});
describe('generateContent', () => {
it('returns a GenerateContentResponse when passed a GCS URI', async () => {

@@ -277,5 +390,3 @@ const req: GenerateContentRequest = {

});
});
describe('generateContent', () => {
it('raises an error when passed an invalid GCS URI', async () => {

@@ -289,5 +400,3 @@ const req: GenerateContentRequest = {

});
});
describe('generateContent', () => {
it('returns a GenerateContentResponse when passed safety_settings and generation_config', async () => {

@@ -312,5 +421,3 @@ const req: GenerateContentRequest = {

});
});
describe('generateContent', () => {
it('updates the base API endpoint when provided', async () => {

@@ -347,5 +454,3 @@ const vertexaiWithBasePath = new VertexAI({

});
});
describe('generateContent', () => {
it('default the base API endpoint when base API not provided', async () => {

@@ -383,5 +488,3 @@ const vertexaiWithoutBasePath = new VertexAI({

});
});
describe('generateContent', () => {
it('removes top_k when it is set to 0', async () => {

@@ -400,6 +503,2 @@ const reqWithEmptyConfigs: GenerateContentRequest = {

};
// const fetchResult = Promise.resolve(
// new Response(JSON.stringify(expectedStreamResult),
// fetchResponseObj));
// const requestSpy = spyOn(global, 'fetch').and.returnValue(fetchResult);
spyOn(StreamFunctions, 'processStream').and.returnValue(

@@ -414,5 +513,3 @@ expectedStreamResult

});
});
describe('generateContent', () => {
it('includes top_k when it is within 1 - 40', async () => {

@@ -440,5 +537,3 @@ const reqWithEmptyConfigs: GenerateContentRequest = {

});
});
describe('generateContent', () => {
it('aggregates citation metadata', async () => {

@@ -466,4 +561,60 @@ const req: GenerateContentRequest = {

});
// With a tool declared, generateContent should surface the mocked
// functionCall response unchanged.
it('returns a FunctionCall when passed a FunctionDeclaration', async () => {
const req: GenerateContentRequest = {
contents: [
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]},
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedResult: GenerateContentResult = {
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL,
};
const expectedStreamResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
stream: testGenerator(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
const resp = await model.generateContent(req);
expect(resp).toEqual(expectedResult);
});
// A 'function' turn that does not immediately follow a functionCall turn
// must be rejected client-side.
// NOTE(review): this `.catch`-only pattern passes vacuously if no error is
// thrown; prefer expectAsync(...).toBeRejectedWithError(...).
it('throws ClientError when functionResponse is not immedidately following functionCall case1', async () => {
const req: GenerateContentRequest = {
contents: [
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]},
{
role: 'function',
parts: TEST_FUNCTION_RESPONSE_PART,
},
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedErrorMessage =
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
await model.generateContent(req).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
// NOTE(review): this request is byte-identical to case1 — it likely was
// meant to exercise a different turn ordering; confirm the intended case.
it('throws ClientError when functionResponse is not immedidately following functionCall case2', async () => {
const req: GenerateContentRequest = {
contents: [
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]},
{
role: 'function',
parts: TEST_FUNCTION_RESPONSE_PART,
},
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedErrorMessage =
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
await model.generateContent(req).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
});
describe('generateContentStream', () => {

@@ -482,2 +633,3 @@ it('returns a GenerateContentResponse when passed text content', async () => {

});
it('returns a GenerateContentResponse when passed a string', async () => {

@@ -492,5 +644,3 @@ const expectedResult: StreamGenerateContentResult = {

});
});
describe('generateContentStream', () => {
it('returns a GenerateContentResponse when passed multi-part content with a GCS URI', async () => {

@@ -508,7 +658,69 @@ const req: GenerateContentRequest = {

});
it('returns a GenerateContentResponse when passed multi-part content with base64 data', async () => {
const req: GenerateContentRequest = {
contents: TEST_MULTIPART_MESSAGE_BASE64,
};
const expectedResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult);
const resp = await model.generateContentStream(req);
expect(resp).toEqual(expectedResult);
});
it('returns a FunctionCall when passed a FunctionDeclaration', async () => {
const req: GenerateContentRequest = {
contents: [
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]},
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedStreamResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
stream: testGenerator(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
const resp = await model.generateContentStream(req);
expect(resp).toEqual(expectedStreamResult);
});
// Streaming variant: a 'function' turn not immediately following a
// functionCall turn must be rejected client-side.
// NOTE(review): the `.catch`-only pattern passes vacuously if no error is
// thrown; prefer expectAsync(...).toBeRejectedWithError(...).
it('throws ClientError when functionResponse is not immedidately following functionCall case1', async () => {
const req: GenerateContentRequest = {
contents: [
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]},
{
role: 'function',
parts: TEST_FUNCTION_RESPONSE_PART,
},
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedErrorMessage =
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
await model.generateContentStream(req).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
// NOTE(review): byte-identical request to case1 — likely meant to cover a
// different ordering; confirm the intended case.
it('throws ClientError when functionResponse is not immedidately following functionCall case2', async () => {
const req: GenerateContentRequest = {
contents: [
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]},
{
role: 'function',
parts: TEST_FUNCTION_RESPONSE_PART,
},
],
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
};
const expectedErrorMessage =
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.';
await model.generateContentStream(req).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
});
// TODO: add a streaming test with a multipart message and inline image data
// (b64 string)
describe('startChat', () => {

@@ -556,2 +768,4 @@ it('returns a ChatSession when passed a request arg', () => {

let chatSessionWithNoArgs: ChatSession;
let chatSessionWithEmptyResponse: ChatSession;
let chatSessionWithFunctionCall: ChatSession;
let vertexai: VertexAI;

@@ -570,2 +784,6 @@ let model: GenerativeModel;

chatSessionWithNoArgs = model.startChat();
chatSessionWithEmptyResponse = model.startChat();
chatSessionWithFunctionCall = model.startChat({
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION,
});
expectedStreamResult = {

@@ -616,9 +834,3 @@ response: Promise.resolve(TEST_MODEL_RESPONSE),

// TODO: unbreak this test. Currently chatSession.history is saving the
// history from the test above instead of resetting and
// expect.toThrowError() is erroring out before the expect condition is
// called
it('throws an error when the model returns an empty response', async () => {
// Reset the chat session history
const req = 'How are you doing today?';

@@ -629,2 +841,19 @@ const expectedResult: GenerateContentResult = {

const expectedStreamResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_EMPTY_MODEL_RESPONSE),
stream: testGeneratorWithEmptyResponse(),
};
spyOn(StreamFunctions, 'processStream').and.returnValue(
expectedStreamResult
);
await expectAsync(
chatSessionWithEmptyResponse.sendMessage(req)
).toBeRejected();
expect(chatSessionWithEmptyResponse.history.length).toEqual(0);
});
it('returns a GenerateContentResponse when passed multi-part content', async () => {
const req = TEST_MULTIPART_MESSAGE[0]['parts'];
const expectedResult: GenerateContentResult = {
response: TEST_MODEL_RESPONSE,
};
const expectedStreamResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),

@@ -636,12 +865,64 @@ stream: testGenerator(),

);
// Shouldn't append anything to history with an empty result
// expect(chatSession.history.length).toEqual(1);
// expect(await chatSession.sendMessage(req))
// .toThrowError('Did not get a response from the model');
const resp = await chatSessionWithNoArgs.sendMessage(req);
expect(resp).toEqual(expectedResult);
console.log(chatSessionWithNoArgs.history, 'hihii');
expect(chatSessionWithNoArgs.history.length).toEqual(2);
});
// TODO: add test cases for different content types passed to
// sendMessage
// Chat function-calling via mocked stream: the functionCall turn and the
// follow-up functionResponse turn each append 2 entries to history (2, then 4).
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => {
const functionCallChatMessage = 'What is the weather in LA?';
const expectedFunctionCallResponse: GenerateContentResult = {
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL,
};
const expectedStreamResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
stream: testGenerator(),
};
const streamSpy = spyOn(StreamFunctions, 'processStream');
streamSpy.and.returnValue(expectedStreamResult);
const response1 = await chatSessionWithFunctionCall.sendMessage(
functionCallChatMessage
);
expect(response1).toEqual(expectedFunctionCallResponse);
expect(chatSessionWithFunctionCall.history.length).toEqual(2);
// Send a follow-up message with a FunctionResponse
const expectedFollowUpResponse: GenerateContentResult = {
response: TEST_MODEL_RESPONSE,
};
const expectedFollowUpStreamResult: StreamGenerateContentResult = {
response: Promise.resolve(TEST_MODEL_RESPONSE),
stream: testGenerator(),
};
streamSpy.and.returnValue(expectedFollowUpStreamResult);
const response2 = await chatSessionWithFunctionCall.sendMessage(
TEST_FUNCTION_RESPONSE_PART
);
expect(response2).toEqual(expectedFollowUpResponse);
expect(chatSessionWithFunctionCall.history.length).toEqual(4);
});
// Empty message arrays are rejected before any request is sent.
// NOTE(review): the `.catch`-only pattern passes vacuously if no error is
// thrown; prefer expectAsync(...).toBeRejectedWithError(...).
it('throw ClientError when request has no content', async () => {
const expectedErrorMessage =
'[VertexAI.ClientError]: No content is provided for sending chat message.';
await chatSessionWithNoArgs.sendMessage([]).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
// A single message may not mix a functionResponse part with other part types.
it('throw ClientError when request mix functionCall part with other types of part', async () => {
const chatRequest = [
'what is the weather like in LA',
TEST_FUNCTION_RESPONSE_PART[0],
];
const expectedErrorMessage =
'[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.';
await chatSessionWithNoArgs.sendMessage(chatRequest).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
});
describe('sendMessageStram', () => {
describe('sendMessageStream', () => {
it('returns a StreamGenerateContentResponse and appends to history', async () => {

@@ -699,2 +980,51 @@ const req = 'How are you doing today?';

});
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => {
  // Turn 1: the question should produce a streamed FunctionCall response.
  const userQuestion = 'What is the weather in LA?';
  const firstTurnStream: StreamGenerateContentResult = {
    response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL),
    stream: testGenerator(),
  };
  const processStreamSpy = spyOn(StreamFunctions, 'processStream');
  processStreamSpy.and.returnValue(firstTurnStream);
  const firstReply = await chatSessionWithFunctionCall.sendMessageStream(
    userQuestion
  );
  expect(firstReply).toEqual(firstTurnStream);
  // user message + model FunctionCall = 2 history entries.
  expect(chatSessionWithFunctionCall.history.length).toEqual(2);
  // Turn 2: return the FunctionResponse; the model streams a normal reply.
  const secondTurnStream: StreamGenerateContentResult = {
    response: Promise.resolve(TEST_MODEL_RESPONSE),
    stream: testGenerator(),
  };
  processStreamSpy.and.returnValue(secondTurnStream);
  const secondReply = await chatSessionWithFunctionCall.sendMessageStream(
    TEST_FUNCTION_RESPONSE_PART
  );
  expect(secondReply).toEqual(secondTurnStream);
  // Two more entries: FunctionResponse message + model reply = 4 total.
  expect(chatSessionWithFunctionCall.history.length).toEqual(4);
});
it('throw ClientError when request has no content', async () => {
const expectedErrorMessage =
'[VertexAI.ClientError]: No content is provided for sending chat message.';
await chatSessionWithNoArgs.sendMessageStream([]).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
it('throw ClientError when request mix functionCall part with other types of part', async () => {
const chatRequest = [
'what is the weather like in LA',
TEST_FUNCTION_RESPONSE_PART[0],
];
const expectedErrorMessage =
'[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.';
await chatSessionWithNoArgs.sendMessageStream(chatRequest).catch(e => {
expect(e.message).toEqual(expectedErrorMessage);
});
});
});

@@ -701,0 +1031,0 @@ });

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is not supported yet

SocketSocket SOC 2 Logo

Product

  • Package Alerts
  • Integrations
  • Docs
  • Pricing
  • FAQ
  • Roadmap
  • Changelog

Packages

npm

Stay in touch

Get open source security insights delivered straight into your inbox.


  • Terms
  • Privacy
  • Security

Made with ⚡️ by Socket Inc