@google-cloud/vertexai
Comparing version 0.3.1 to 0.4.0
{ | ||
".": "0.3.1" | ||
".": "0.4.0" | ||
} |
@@ -23,2 +23,2 @@ /** | ||
*/ | ||
export declare function countTokens(location: string, project: string, publisherModelEndpoint: string, token: Promise<any>, apiEndpoint: string, request: CountTokensRequest): Promise<CountTokensResponse>; | ||
export declare function countTokens(location: string, project: string, publisherModelEndpoint: string, token: Promise<any>, request: CountTokensRequest, apiEndpoint?: string): Promise<CountTokensResponse>; |
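In 0.4.0, countTokens moves request ahead of apiEndpoint and makes the endpoint optional, so 0.3.1 call sites must reorder their arguments. A minimal call-site sketch against the new signature; the import path, project, location, and model strings are illustrative placeholders, not values from this diff:

import {countTokens} from './functions/count_tokens';
import {CountTokensRequest} from './types/content';

async function demoCountTokens(token: Promise<string>): Promise<void> {
  const request: CountTokensRequest = {
    contents: [{role: 'user', parts: [{text: 'How are you doing today?'}]}],
  };
  const resp = await countTokens(
    'us-central1',                         // location
    'my-project',                          // project (placeholder)
    'publishers/google/models/gemini-pro', // publisherModelEndpoint (placeholder)
    token,
    request                                // request is now the fifth argument
    // apiEndpoint omitted: the default regionalized endpoint is used
  );
  console.log(resp.totalTokens);
}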
@@ -28,3 +28,3 @@ "use strict"; | ||
*/ | ||
async function countTokens(location, project, publisherModelEndpoint, token, apiEndpoint, request) { | ||
async function countTokens(location, project, publisherModelEndpoint, token, request, apiEndpoint) { | ||
const response = await (0, post_request_1.postRequest)({ | ||
@@ -41,3 +41,5 @@ region: location, | ||
}); | ||
(0, post_fetch_processing_1.throwErrorIfNotOK)(response); | ||
await (0, post_fetch_processing_1.throwErrorIfNotOK)(response).catch(e => { | ||
throw e; | ||
}); | ||
return (0, post_fetch_processing_1.processCountTokenResponse)(response); | ||
@@ -44,0 +46,0 @@ } |
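throwErrorIfNotOK is awaited here because it is now async: it reads the JSON error body before throwing (see the post_fetch_processing hunks below). A sketch of the pattern a caller has to follow, assuming the helper's internal module path; without the await, a non-OK response would fall through to processCountTokenResponse while the rejection went unhandled:

import {throwErrorIfNotOK} from './post_fetch_processing';

async function fetchChecked(url: string): Promise<Response> {
  const response = await fetch(url); // url is a placeholder
  // Throws ClientError for 4xx statuses and GoogleGenerativeAIError for any
  // other non-OK status, with the parsed error body appended to the message.
  await throwErrorIfNotOK(response);
  return response;
}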
@@ -17,1 +17,16 @@ /** | ||
*/ | ||
/** | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResponse object with the response candidates. | ||
*/ | ||
import { GenerateContentRequest, GenerateContentResult, GenerationConfig, SafetySetting, StreamGenerateContentResult } from '../types/content'; | ||
export declare function generateContent(location: string, project: string, publisherModelEndpoint: string, token: Promise<any>, request: GenerateContentRequest | string, apiEndpoint?: string, generation_config?: GenerationConfig, safety_settings?: SafetySetting[]): Promise<GenerateContentResult>; | ||
/** | ||
* Make an async stream request to generate content. The response is | ||
* returned as a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link | ||
* StreamGenerateContentResult} | ||
*/ | ||
export declare function generateContentStream(location: string, project: string, publisherModelEndpoint: string, token: Promise<any>, request: GenerateContentRequest | string, apiEndpoint?: string, generation_config?: GenerationConfig, safety_settings?: SafetySetting[]): Promise<StreamGenerateContentResult>; |
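Both declarations now accept a bare string as request (it is wrapped into a single user turn; see the pre_fetch_processing hunks below) and take optional per-call generation_config and safety_settings. A consumption sketch for the streaming variant, with placeholder import path and identifiers:

import {generateContentStream} from './functions/generate_content';

async function demoStream(token: Promise<string>): Promise<void> {
  const result = await generateContentStream(
    'us-central1',                         // location
    'my-project',                          // project (placeholder)
    'publishers/google/models/gemini-pro', // publisherModelEndpoint (placeholder)
    token,
    'Tell me a short story.'               // a plain string becomes one user turn
  );
  // Incremental chunks: one complete GenerateContentResponse per iteration...
  for await (const chunk of result.stream) {
    console.log(chunk.candidates[0]?.content.parts[0]?.text ?? '');
  }
  // ...and a single aggregated response once the stream is exhausted.
  const aggregated = await result.response;
  console.log(aggregated.candidates.length);
}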
@@ -18,2 +18,80 @@ "use strict"; | ||
*/ | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.generateContentStream = exports.generateContent = void 0; | ||
const errors_1 = require("../types/errors"); | ||
const constants = require("../util/constants"); | ||
const post_fetch_processing_1 = require("./post_fetch_processing"); | ||
const post_request_1 = require("./post_request"); | ||
const pre_fetch_processing_1 = require("./pre_fetch_processing"); | ||
async function generateContent(location, project, publisherModelEndpoint, token, request, apiEndpoint, generation_config, safety_settings) { | ||
var _a, _b, _c; | ||
request = (0, pre_fetch_processing_1.formatContentRequest)(request, generation_config, safety_settings); | ||
(0, pre_fetch_processing_1.validateGenerateContentRequest)(request); | ||
if (request.generation_config) { | ||
request.generation_config = (0, pre_fetch_processing_1.validateGenerationConfig)(request.generation_config); | ||
} | ||
const generateContentRequest = { | ||
contents: request.contents, | ||
generation_config: (_a = request.generation_config) !== null && _a !== void 0 ? _a : generation_config, | ||
safety_settings: (_b = request.safety_settings) !== null && _b !== void 0 ? _b : safety_settings, | ||
tools: (_c = request.tools) !== null && _c !== void 0 ? _c : [], | ||
}; | ||
const apiVersion = request.tools ? 'v1beta1' : 'v1'; | ||
const response = await (0, post_request_1.postRequest)({ | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: constants.GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
apiVersion: apiVersion, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
await (0, post_fetch_processing_1.throwErrorIfNotOK)(response).catch(e => { | ||
throw e; | ||
}); | ||
return (0, post_fetch_processing_1.processNonStream)(response); | ||
} | ||
exports.generateContent = generateContent; | ||
/** | ||
* Make an async stream request to generate content. The response is | ||
* returned as a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link | ||
* StreamGenerateContentResult} | ||
*/ | ||
async function generateContentStream(location, project, publisherModelEndpoint, token, request, apiEndpoint, generation_config, safety_settings) { | ||
var _a, _b, _c; | ||
request = (0, pre_fetch_processing_1.formatContentRequest)(request, generation_config, safety_settings); | ||
(0, pre_fetch_processing_1.validateGenerateContentRequest)(request); | ||
if (request.generation_config) { | ||
request.generation_config = (0, pre_fetch_processing_1.validateGenerationConfig)(request.generation_config); | ||
} | ||
const generateContentRequest = { | ||
contents: request.contents, | ||
generation_config: (_a = request.generation_config) !== null && _a !== void 0 ? _a : generation_config, | ||
safety_settings: (_b = request.safety_settings) !== null && _b !== void 0 ? _b : safety_settings, | ||
tools: (_c = request.tools) !== null && _c !== void 0 ? _c : [], | ||
}; | ||
const apiVersion = request.tools ? 'v1beta1' : 'v1'; | ||
const response = await (0, post_request_1.postRequest)({ | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
apiVersion: apiVersion, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
await (0, post_fetch_processing_1.throwErrorIfNotOK)(response).catch(e => { | ||
throw e; | ||
}); | ||
return (0, post_fetch_processing_1.processStream)(response); | ||
} | ||
exports.generateContentStream = generateContentStream; | ||
//# sourceMappingURL=generate_content.js.map |
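Two behaviors above are easy to miss. Values set on the request object shadow the method-level parameters (the void 0 fallbacks are compiled from ??), and attaching tools silently routes the call to the v1beta1 API surface instead of v1. A restatement of that logic as a standalone sketch, not a public API:

import {GenerateContentRequest, GenerationConfig} from './types/content';

// Mirrors the `request.tools ? 'v1beta1' : 'v1'` branch above.
function resolveApiVersion(request: GenerateContentRequest): 'v1' | 'v1beta1' {
  return request.tools ? 'v1beta1' : 'v1';
}

// Mirrors `request.generation_config ?? generation_config`: the request-level
// value, when present, wins over the one passed as a method parameter.
function resolveGenerationConfig(
  request: GenerateContentRequest,
  methodLevel?: GenerationConfig
): GenerationConfig | undefined {
  return request.generation_config ?? methodLevel;
}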
@@ -19,1 +19,2 @@ /** | ||
export { postRequest } from './post_request'; | ||
export { generateContent, generateContentStream } from './generate_content'; |
@@ -19,3 +19,3 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.postRequest = exports.countTokens = void 0; | ||
exports.generateContentStream = exports.generateContent = exports.postRequest = exports.countTokens = void 0; | ||
var count_tokens_1 = require("./count_tokens"); | ||
@@ -25,2 +25,5 @@ Object.defineProperty(exports, "countTokens", { enumerable: true, get: function () { return count_tokens_1.countTokens; } }); | ||
Object.defineProperty(exports, "postRequest", { enumerable: true, get: function () { return post_request_1.postRequest; } }); | ||
var generate_content_1 = require("./generate_content"); | ||
Object.defineProperty(exports, "generateContent", { enumerable: true, get: function () { return generate_content_1.generateContent; } }); | ||
Object.defineProperty(exports, "generateContentStream", { enumerable: true, get: function () { return generate_content_1.generateContentStream; } }); | ||
//# sourceMappingURL=index.js.map |
@@ -17,8 +17,30 @@ /** | ||
*/ | ||
import { CountTokensResponse } from '../types/content'; | ||
export declare function throwErrorIfNotOK(response: Response | undefined): void; | ||
import { CountTokensResponse, GenerateContentResult, StreamGenerateContentResult } from '../types/content'; | ||
export declare function throwErrorIfNotOK(response: Response | undefined): Promise<void>; | ||
/** | ||
* Process a response.body stream from the backend and return an | ||
* iterator that provides one complete GenerateContentResponse at a time | ||
* and a promise that resolves with a single aggregated | ||
* GenerateContentResponse. | ||
* | ||
* @param response - Response from a fetch call | ||
* @ignore | ||
*/ | ||
export declare function processStream(response: Response | undefined): Promise<StreamGenerateContentResult>; | ||
/** | ||
* Reads a raw stream from the fetch response and joins incomplete | ||
* chunks, returning a new stream that provides a single complete | ||
* GenerateContentResponse in each iteration. | ||
* @ignore | ||
*/ | ||
export declare function getResponseStream<T>(inputStream: ReadableStream<string>): ReadableStream<T>; | ||
/** | ||
* Process model responses from generateContent | ||
* @ignore | ||
*/ | ||
export declare function processNonStream(response: any): Promise<GenerateContentResult>; | ||
/** | ||
* Process model responses from countTokens | ||
* @ignore | ||
*/ | ||
export declare function processCountTokenResponse(response: any): CountTokensResponse; | ||
export declare function processCountTokenResponse(response: any): Promise<CountTokensResponse>; |
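throwErrorIfNotOK and processCountTokenResponse now return promises, which breaks any caller that used them synchronously, and the thrown message now embeds the parsed error body. A hedged sketch of consumer-side handling, assuming both error classes extend Error and are importable from the package's types:

import {ClientError, GoogleGenerativeAIError} from './types/errors';

async function withErrorHandling<T>(call: () => Promise<T>): Promise<T | undefined> {
  try {
    return await call();
  } catch (e) {
    if (e instanceof ClientError) {
      // 4xx: the request itself was rejected; retrying it unchanged won't help.
      console.error('client error:', e.message);
    } else if (e instanceof GoogleGenerativeAIError) {
      // A failed fetch or a non-4xx, non-OK status; possibly transient.
      console.error('service error:', e.message);
    } else {
      throw e;
    }
    return undefined;
  }
}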
"use strict"; | ||
/** | ||
* @license | ||
* Copyright 2023 Google LLC | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.processCountTokenResponse = exports.throwErrorIfNotOK = void 0; | ||
exports.processCountTokenResponse = exports.processNonStream = exports.getResponseStream = exports.processStream = exports.throwErrorIfNotOK = void 0; | ||
const errors_1 = require("../types/errors"); | ||
function throwErrorIfNotOK(response) { | ||
async function throwErrorIfNotOK(response) { | ||
if (response === undefined) { | ||
throw new errors_1.GoogleGenerativeAIError('response is undefined'); | ||
} | ||
const status = response.status; | ||
const statusText = response.statusText; | ||
const errorMessage = `got status: ${status} ${statusText}`; | ||
if (status >= 400 && status < 500) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
if (!response.ok) { | ||
const status = response.status; | ||
const statusText = response.statusText; | ||
const errorBody = await response.json(); | ||
const errorMessage = `got status: ${status} ${statusText}. ${JSON.stringify(errorBody)}`; | ||
if (status >= 400 && status < 500) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
throw new errors_1.GoogleGenerativeAIError(errorMessage); | ||
@@ -20,12 +37,185 @@ } | ||
exports.throwErrorIfNotOK = throwErrorIfNotOK; | ||
const responseLineRE = /^data: (.*)(?:\n\n|\r\r|\r\n\r\n)/; | ||
async function* generateResponseSequence(stream) { | ||
const reader = stream.getReader(); | ||
while (true) { | ||
const { value, done } = await reader.read(); | ||
if (done) { | ||
break; | ||
} | ||
yield value; | ||
} | ||
} | ||
/** | ||
* Process a response.body stream from the backend and return an | ||
* iterator that provides one complete GenerateContentResponse at a time | ||
* and a promise that resolves with a single aggregated | ||
* GenerateContentResponse. | ||
* | ||
* @param response - Response from a fetch call | ||
* @ignore | ||
*/ | ||
async function processStream(response) { | ||
if (response === undefined) { | ||
throw new Error('Error processing stream because response === undefined'); | ||
} | ||
if (!response.body) { | ||
throw new Error('Error processing stream because response.body not found'); | ||
} | ||
const inputStream = response.body.pipeThrough(new TextDecoderStream('utf8', { fatal: true })); | ||
const responseStream = getResponseStream(inputStream); | ||
const [stream1, stream2] = responseStream.tee(); | ||
return Promise.resolve({ | ||
stream: generateResponseSequence(stream1), | ||
response: getResponsePromise(stream2), | ||
}); | ||
} | ||
exports.processStream = processStream; | ||
async function getResponsePromise(stream) { | ||
const allResponses = []; | ||
const reader = stream.getReader(); | ||
// eslint-disable-next-line no-constant-condition | ||
while (true) { | ||
const { done, value } = await reader.read(); | ||
if (done) { | ||
return aggregateResponses(allResponses); | ||
} | ||
allResponses.push(value); | ||
} | ||
} | ||
/** | ||
* Reads a raw stream from the fetch response and joins incomplete | ||
* chunks, returning a new stream that provides a single complete | ||
* GenerateContentResponse in each iteration. | ||
* @ignore | ||
*/ | ||
function getResponseStream(inputStream) { | ||
const reader = inputStream.getReader(); | ||
const stream = new ReadableStream({ | ||
start(controller) { | ||
let currentText = ''; | ||
return pump(); | ||
function pump() { | ||
return reader.read().then(({ value, done }) => { | ||
if (done) { | ||
if (currentText.trim()) { | ||
controller.error(new Error('Failed to parse stream')); | ||
return; | ||
} | ||
controller.close(); | ||
return; | ||
} | ||
currentText += value; | ||
let match = currentText.match(responseLineRE); | ||
let parsedResponse; | ||
while (match) { | ||
try { | ||
parsedResponse = JSON.parse(match[1]); | ||
} | ||
catch (e) { | ||
controller.error(new Error(`Error parsing JSON response: "${match[1]}"`)); | ||
return; | ||
} | ||
controller.enqueue(parsedResponse); | ||
currentText = currentText.substring(match[0].length); | ||
match = currentText.match(responseLineRE); | ||
} | ||
return pump(); | ||
}); | ||
} | ||
}, | ||
}); | ||
return stream; | ||
} | ||
exports.getResponseStream = getResponseStream; | ||
/** | ||
* Aggregates an array of `GenerateContentResponse`s into a single | ||
* GenerateContentResponse. | ||
* @ignore | ||
*/ | ||
function aggregateResponses(responses) { | ||
var _a, _b; | ||
const lastResponse = responses[responses.length - 1]; | ||
if (lastResponse === undefined) { | ||
throw new Error('Error processing stream because the response is undefined'); | ||
} | ||
const aggregatedResponse = { | ||
candidates: [], | ||
promptFeedback: lastResponse.promptFeedback, | ||
}; | ||
for (const response of responses) { | ||
for (let i = 0; i < response.candidates.length; i++) { | ||
if (!aggregatedResponse.candidates[i]) { | ||
aggregatedResponse.candidates[i] = { | ||
index: response.candidates[i].index, | ||
content: { | ||
role: response.candidates[i].content.role, | ||
parts: [{ text: '' }], | ||
}, | ||
}; | ||
} | ||
if (response.candidates[i].citationMetadata) { | ||
if (!((_a = aggregatedResponse.candidates[i].citationMetadata) === null || _a === void 0 ? void 0 : _a.citationSources)) { | ||
aggregatedResponse.candidates[i].citationMetadata = { | ||
citationSources: [], | ||
}; | ||
} | ||
const existingMetadata = (_b = response.candidates[i].citationMetadata) !== null && _b !== void 0 ? _b : {}; | ||
if (aggregatedResponse.candidates[i].citationMetadata) { | ||
// Merge the chunk's citation sources, not the metadata object itself, into the aggregate. | ||
aggregatedResponse.candidates[i].citationMetadata.citationSources = | ||
aggregatedResponse.candidates[i].citationMetadata.citationSources.concat(existingMetadata.citationSources || []); | ||
} | ||
} | ||
aggregatedResponse.candidates[i].finishReason = | ||
response.candidates[i].finishReason; | ||
aggregatedResponse.candidates[i].finishMessage = | ||
response.candidates[i].finishMessage; | ||
aggregatedResponse.candidates[i].safetyRatings = | ||
response.candidates[i].safetyRatings; | ||
if ('parts' in response.candidates[i].content) { | ||
for (const part of response.candidates[i].content.parts) { | ||
if (part.text) { | ||
aggregatedResponse.candidates[i].content.parts[0].text += part.text; | ||
} | ||
if (part.functionCall) { | ||
aggregatedResponse.candidates[i].content.parts[0].functionCall = | ||
part.functionCall; | ||
// the empty 'text' key should be removed if functionCall is in the | ||
// response | ||
delete aggregatedResponse.candidates[i].content.parts[0].text; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
aggregatedResponse.promptFeedback = | ||
responses[responses.length - 1].promptFeedback; | ||
return aggregatedResponse; | ||
} | ||
/** | ||
* Process model responses from generateContent | ||
* @ignore | ||
*/ | ||
async function processNonStream(response) { | ||
if (response !== undefined) { | ||
// ts-ignore | ||
const responseJson = await response.json(); | ||
return Promise.resolve({ | ||
response: responseJson, | ||
}); | ||
} | ||
return Promise.resolve({ | ||
response: { candidates: [] }, | ||
}); | ||
} | ||
exports.processNonStream = processNonStream; | ||
/** | ||
* Process model responses from countTokens | ||
* @ignore | ||
*/ | ||
function processCountTokenResponse(response) { | ||
async function processCountTokenResponse(response) { | ||
// ts-ignore | ||
const responseJson = response.json(); | ||
return responseJson; | ||
return response.json(); | ||
} | ||
exports.processCountTokenResponse = processCountTokenResponse; | ||
//# sourceMappingURL=post_fetch_processing.js.map |
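getResponseStream assumes server-sent-events framing: each payload is a data: line closed by a blank line (responseLineRE), and partial reads are buffered in currentText until a frame completes. A self-contained sketch of that behavior, assuming a runtime such as Node 18+ where web ReadableStreams are async-iterable:

import {getResponseStream} from './post_fetch_processing';

async function demoFraming(): Promise<void> {
  const chunks = [
    'data: {"candidates":[{"index":0}]}\n\n', // one complete frame
    'data: {"candidates":',                   // split frame: buffered...
    '[{"index":1}]}\n\n',                     // ...until the blank line closes it
  ];
  const input = new ReadableStream<string>({
    start(controller) {
      for (const c of chunks) controller.enqueue(c);
      controller.close();
    },
  });
  type Chunk = {candidates: Array<{index: number}>};
  for await (const parsed of getResponseStream<Chunk>(input)) {
    console.log(parsed.candidates[0].index); // logs 0, then 1
  }
}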
@@ -17,1 +17,5 @@ /** | ||
*/ | ||
import { GenerateContentRequest, GenerationConfig, SafetySetting } from '../types/content'; | ||
export declare function formatContentRequest(request: GenerateContentRequest | string, generation_config?: GenerationConfig, safety_settings?: SafetySetting[]): GenerateContentRequest; | ||
export declare function validateGenerateContentRequest(request: GenerateContentRequest): void; | ||
export declare function validateGenerationConfig(generation_config: GenerationConfig): GenerationConfig; |
@@ -18,2 +18,60 @@ "use strict"; | ||
*/ | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.validateGenerationConfig = exports.validateGenerateContentRequest = exports.formatContentRequest = void 0; | ||
const errors_1 = require("../types/errors"); | ||
const constants = require("../util/constants"); | ||
function formatContentRequest(request, generation_config, safety_settings) { | ||
if (typeof request === 'string') { | ||
return { | ||
contents: [{ role: constants.USER_ROLE, parts: [{ text: request }] }], | ||
generation_config: generation_config, | ||
safety_settings: safety_settings, | ||
}; | ||
} | ||
else { | ||
return request; | ||
} | ||
} | ||
exports.formatContentRequest = formatContentRequest; | ||
function validateGenerateContentRequest(request) { | ||
validateGcsInput(request.contents); | ||
validateFunctionResponseRequest(request.contents); | ||
} | ||
exports.validateGenerateContentRequest = validateGenerateContentRequest; | ||
function validateGenerationConfig(generation_config) { | ||
if ('top_k' in generation_config) { | ||
if (!(generation_config.top_k > 0) || !(generation_config.top_k <= 40)) { | ||
delete generation_config.top_k; | ||
} | ||
} | ||
return generation_config; | ||
} | ||
exports.validateGenerationConfig = validateGenerationConfig; | ||
function validateGcsInput(contents) { | ||
for (const content of contents) { | ||
for (const part of content.parts) { | ||
if ('file_data' in part) { | ||
// @ts-ignore | ||
const uri = part['file_data']['file_uri']; | ||
if (!uri.startsWith('gs://')) { | ||
throw new URIError(`Found invalid Google Cloud Storage URI ${uri}, Google Cloud Storage URIs must start with gs://`); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
function validateFunctionResponseRequest(contents) { | ||
const latestContentPart = contents[contents.length - 1].parts[0]; | ||
if (!('functionResponse' in latestContentPart)) { | ||
return; | ||
} | ||
const errorMessage = 'Please ensure that function response turn comes immediately after a function call turn.'; | ||
if (contents.length < 2) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
const secondLatestContentPart = contents[contents.length - 2].parts[0]; | ||
if (!('functionCall' in secondLatestContentPart)) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
} | ||
//# sourceMappingURL=pre_fetch_processing.js.map |
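formatContentRequest is what lets callers pass a bare string: the text becomes a single user-role turn with the method-level configs attached. validateGenerationConfig silently drops an out-of-range top_k rather than throwing. A quick sketch of both, with an assumed internal import path and an illustrative temperature value:

import {formatContentRequest, validateGenerationConfig} from './pre_fetch_processing';

const req = formatContentRequest('How are you doing today?', {temperature: 0.2});
// req.contents => [{role: 'user', parts: [{text: 'How are you doing today?'}]}]

console.log(validateGenerationConfig({top_k: 0}));  // {} (top_k outside (0, 40] is deleted)
console.log(validateGenerationConfig({top_k: 20})); // {top_k: 20} (kept as-is)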
@@ -17,2 +17,7 @@ /** | ||
*/ | ||
export {}; | ||
import { GenerateContentResponse } from '../../types'; | ||
/** | ||
* Returns a generator, used to mock the generateContentStream response | ||
* @ignore | ||
*/ | ||
export declare function testGenerator(): AsyncGenerator<GenerateContentResponse>; |
@@ -19,13 +19,185 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.testGenerator = void 0; | ||
const types_1 = require("../../types"); | ||
const util_1 = require("../../util"); | ||
const count_tokens_1 = require("../count_tokens"); | ||
const util_1 = require("../../util"); | ||
const generate_content_1 = require("../generate_content"); | ||
const StreamFunctions = require("../post_fetch_processing"); | ||
const TEST_PROJECT = 'test-project'; | ||
const TEST_LOCATION = 'test-location'; | ||
const TEST_PUBLISHER_MODEL_ENDPOINT = 'test-publisher-model-endpoint'; | ||
const TEST_TOKEN_PROMISE = Promise.resolve('test-token'); | ||
const TEST_TOKEN = 'testtoken'; | ||
const TEST_TOKEN_PROMISE = Promise.resolve(TEST_TOKEN); | ||
const TEST_API_ENDPOINT = 'test-api-endpoint'; | ||
const TEST_CHAT_MESSSAGE_TEXT = 'How are you doing today?'; | ||
const TEST_CHAT_MESSAGE_TEXT = 'How are you doing today?'; | ||
const TEST_USER_CHAT_MESSAGE = [ | ||
{ role: util_1.constants.USER_ROLE, parts: [{ text: TEST_CHAT_MESSSAGE_TEXT }] }, | ||
{ role: util_1.constants.USER_ROLE, parts: [{ text: TEST_CHAT_MESSAGE_TEXT }] }, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [ | ||
{ text: TEST_CHAT_MESSAGE_TEXT }, | ||
{ | ||
file_data: { | ||
file_uri: 'gs://test_bucket/test_image.jpeg', | ||
mime_type: 'image/jpeg', | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [ | ||
{ text: TEST_CHAT_MESSAGE_TEXT }, | ||
{ file_data: { file_uri: 'test_image.jpeg', mime_type: 'image/jpeg' } }, | ||
], | ||
}, | ||
]; | ||
const TEST_SAFETY_SETTINGS = [ | ||
{ | ||
category: types_1.HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
threshold: types_1.HarmBlockThreshold.BLOCK_ONLY_HIGH, | ||
}, | ||
]; | ||
const TEST_SAFETY_RATINGS = [ | ||
{ | ||
category: types_1.HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
probability: types_1.HarmProbability.NEGLIGIBLE, | ||
}, | ||
]; | ||
const TEST_GENERATION_CONFIG = { | ||
candidate_count: 1, | ||
stop_sequences: ['hello'], | ||
}; | ||
const TEST_CANDIDATES = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: util_1.constants.MODEL_ROLE, | ||
parts: [{ text: 'Im doing great! How are you?' }], | ||
}, | ||
finishReason: types_1.FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
citationMetadata: { | ||
citationSources: [ | ||
{ | ||
startIndex: 367, | ||
endIndex: 491, | ||
uri: 'https://www.numerade.com/ask/question/why-does-the-uncertainty-principle-make-it-impossible-to-predict-a-trajectory-for-the-clectron-95172/', | ||
}, | ||
], | ||
}, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE = { | ||
candidates: TEST_CANDIDATES, | ||
usage_metadata: { prompt_token_count: 0, candidates_token_count: 0 }, | ||
}; | ||
const TEST_FUNCTION_CALL_RESPONSE = { | ||
functionCall: { | ||
name: 'get_current_weather', | ||
args: { | ||
location: 'LA', | ||
unit: 'fahrenheit', | ||
}, | ||
}, | ||
}; | ||
const TEST_CANDIDATES_WITH_FUNCTION_CALL = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: util_1.constants.MODEL_ROLE, | ||
parts: [TEST_FUNCTION_CALL_RESPONSE], | ||
}, | ||
finishReason: types_1.FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL = { | ||
candidates: TEST_CANDIDATES_WITH_FUNCTION_CALL, | ||
}; | ||
const TEST_FUNCTION_RESPONSE_PART = [ | ||
{ | ||
functionResponse: { | ||
name: 'get_current_weather', | ||
response: { name: 'get_current_weather', content: { weather: 'super nice' } }, | ||
}, | ||
}, | ||
]; | ||
const TEST_CANDIDATES_MISSING_ROLE = [ | ||
{ | ||
index: 1, | ||
content: { parts: [{ text: 'Im doing great! How are you?' }] }, | ||
finish_reason: 0, | ||
finish_message: '', | ||
safety_ratings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com'; | ||
const TEST_GCS_FILENAME = 'gs://test_bucket/test_image.jpeg'; | ||
const TEST_MULTIPART_MESSAGE = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [ | ||
{ text: 'What is in this picture?' }, | ||
{ file_data: { file_uri: TEST_GCS_FILENAME, mime_type: 'image/jpeg' } }, | ||
], | ||
}, | ||
]; | ||
const BASE_64_IMAGE = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; | ||
const INLINE_DATA_FILE_PART = { | ||
inline_data: { | ||
data: BASE_64_IMAGE, | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const TEST_MULTIPART_MESSAGE_BASE64 = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [{ text: 'What is in this picture?' }, INLINE_DATA_FILE_PART], | ||
}, | ||
]; | ||
const TEST_TOOLS_WITH_FUNCTION_DECLARATION = [ | ||
{ | ||
function_declarations: [ | ||
{ | ||
name: 'get_current_weather', | ||
description: 'get weather in a given location', | ||
parameters: { | ||
type: types_1.FunctionDeclarationSchemaType.OBJECT, | ||
properties: { | ||
location: { type: types_1.FunctionDeclarationSchemaType.STRING }, | ||
unit: { | ||
type: types_1.FunctionDeclarationSchemaType.STRING, | ||
enum: ['celsius', 'fahrenheit'], | ||
}, | ||
}, | ||
required: ['location'], | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const fetchResponseObj = { | ||
status: 200, | ||
statusText: 'OK', | ||
ok: true, | ||
headers: { 'Content-Type': 'application/json' }, | ||
url: 'url', | ||
}; | ||
/** | ||
* Returns a generator, used to mock the generateContentStream response | ||
* @ignore | ||
*/ | ||
async function* testGenerator() { | ||
yield { | ||
candidates: TEST_CANDIDATES, | ||
}; | ||
} | ||
exports.testGenerator = testGenerator; | ||
describe('countTokens', () => { | ||
@@ -36,9 +208,2 @@ const req = { | ||
it('return expected response when OK', async () => { | ||
const fetchResponseObj = { | ||
status: 200, | ||
statusText: 'OK', | ||
ok: true, | ||
headers: { 'Content-Type': 'application/json' }, | ||
url: 'url', | ||
}; | ||
const expectedResponseBody = { | ||
@@ -49,3 +214,3 @@ totalTokens: 1, | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
const resp = await (0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_API_ENDPOINT, req); | ||
const resp = await (0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResponseBody); | ||
@@ -59,10 +224,22 @@ }); | ||
}; | ||
const body = {}; | ||
const body = { | ||
code: 500, | ||
message: 'service is having downtime', | ||
status: 'INTERNAL_SERVER_ERROR', | ||
}; | ||
const response = new Response(JSON.stringify(body), fetch500Obj); | ||
const expectedErrorMessage = '[VertexAI.GoogleGenerativeAIError]: got status: 500 Internal Server Error'; | ||
const expectedErrorMessage = '[VertexAI.GoogleGenerativeAIError]: got status: 500 Internal Server Error. {"code":500,"message":"service is having downtime","status":"INTERNAL_SERVER_ERROR"}'; | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_API_ENDPOINT, req)).toBeRejected(); | ||
await (0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_API_ENDPOINT, req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejected(); | ||
// TODO: update jasmine version or use flush to uncomment | ||
// await countTokens( | ||
// TEST_LOCATION, | ||
// TEST_PROJECT, | ||
// TEST_PUBLISHER_MODEL_ENDPOINT, | ||
// TEST_TOKEN_PROMISE, | ||
// req, | ||
// TEST_API_ENDPOINT | ||
// ).catch(e => { | ||
// expect(e.message).toEqual(expectedErrorMessage); | ||
// }); | ||
}); | ||
@@ -75,12 +252,302 @@ it('throw ClientError when not OK and 4XX', async () => { | ||
}; | ||
const body = {}; | ||
const body = { | ||
code: 400, | ||
message: 'request is invalid', | ||
status: 'INVALID_ARGUMENT', | ||
}; | ||
const response = new Response(JSON.stringify(body), fetch400Obj); | ||
const expectedErrorMessage = '[VertexAI.ClientError]: got status: 400 Bad Request'; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: got status: 400 Bad Request. {"code":400,"message":"request is invalid","status":"INVALID_ARGUMENT"}'; | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_API_ENDPOINT, req)).toBeRejected(); | ||
await (0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_API_ENDPOINT, req).catch(e => { | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejected(); | ||
// TODO: update jasmine version or use flush to uncomment | ||
// await countTokens( | ||
// TEST_LOCATION, | ||
// TEST_PROJECT, | ||
// TEST_PUBLISHER_MODEL_ENDPOINT, | ||
// TEST_TOKEN_PROMISE, | ||
// req, | ||
// TEST_API_ENDPOINT | ||
// ).catch(e => { | ||
// expect(e.message).toEqual(expectedErrorMessage); | ||
// }); | ||
}); | ||
}); | ||
describe('generateContent', () => { | ||
let expectedStreamResult; | ||
let fetchSpy; | ||
beforeEach(() => { | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = new Response(JSON.stringify(expectedStreamResult), fetchResponseObj); | ||
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
}); | ||
it('returns a GenerateContentResponse', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_CHAT_MESSAGE_TEXT, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a GCS URI', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('raises an error when passed an invalid GCS URI', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE, | ||
}; | ||
await expectAsync((0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejectedWithError(URIError); | ||
}); | ||
it('returns a GenerateContentResponse when passed safety_settings and generation_config', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
safety_settings: TEST_SAFETY_SETTINGS, | ||
generation_config: TEST_GENERATION_CONFIG, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('updates the base API endpoint when provided', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_ENDPOINT_BASE_PATH); | ||
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain(TEST_ENDPOINT_BASE_PATH); | ||
}); | ||
it('removes top_k when it is set to 0', async () => { | ||
const reqWithEmptyConfigs = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: { top_k: 0 }, | ||
safety_settings: [], | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, reqWithEmptyConfigs, TEST_API_ENDPOINT); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).not.toContain('top_k'); | ||
} | ||
}); | ||
it('includes top_k when it is within 1 - 40', async () => { | ||
const reqWithEmptyConfigs = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: { top_k: 1 }, | ||
safety_settings: [], | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, reqWithEmptyConfigs, TEST_API_ENDPOINT); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).toContain('top_k'); | ||
} | ||
}); | ||
it('aggregates citation metadata', async () => { | ||
var _a; | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect((_a = resp.response.candidates[0].citationMetadata) === null || _a === void 0 ? void 0 : _a.citationSources.length).toEqual(TEST_MODEL_RESPONSE.candidates[0].citationMetadata.citationSources.length); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('throws ClientError when functionResponse is not immediately following functionCall case1', async () => { | ||
const req = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{ text: 'What is the weather like in Boston?' }], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when functionResponse is not immediately following functionCall case2', async () => { | ||
const req = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{ text: 'What is the weather like in Boston?' }], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('generateContentStream', () => { | ||
let expectedStreamResult; | ||
let fetchSpy; | ||
beforeEach(() => { | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = new Response(JSON.stringify(expectedStreamResult), fetchResponseObj); | ||
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed text content', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_CHAT_MESSAGE_TEXT, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with a GCS URI', async () => { | ||
const req = { | ||
contents: TEST_MULTIPART_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with base64 data', async () => { | ||
const req = { | ||
contents: TEST_MULTIPART_MESSAGE_BASE64, | ||
}; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedStreamResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedStreamResult); | ||
}); | ||
it('throws ClientError when functionResponse is not immediately following functionCall case1', async () => { | ||
const req = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{ text: 'What is the weather like in Boston?' }], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when functionResponse is not immediately following functionCall case2', async () => { | ||
const req = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{ text: 'What is the weather like in Boston?' }], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
//# sourceMappingURL=functions_test.js.map |
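The specs above share one mocking recipe: stub global.fetch with a canned Response, spy on processNonStream or processStream, and let testGenerator stand in for a live stream. A hypothetical extra spec in the same Jasmine style, reusing the constants defined above; it is sketched standalone because the existing describe blocks already stub fetch in their beforeEach, and it assumes postRequest embeds location and project in the URL path, much as the 'updates the base API endpoint' spec assumes for the endpoint:

it('builds the request URL from the location and project', async () => {
  const fetchResult = new Response(JSON.stringify({}), fetchResponseObj);
  const fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult);
  spyOn(StreamFunctions, 'processNonStream').and.resolveTo({
    response: TEST_MODEL_RESPONSE,
  });
  await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, { contents: TEST_USER_CHAT_MESSAGE }, TEST_API_ENDPOINT);
  const url = fetchSpy.calls.allArgs()[0][0].toString();
  expect(url).toContain(TEST_LOCATION); // assumption: location is part of the URL
  expect(url).toContain(TEST_PROJECT);  // assumption: project is part of the URL
});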
@@ -17,177 +17,3 @@ /** | ||
*/ | ||
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; | ||
import { Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResult, GenerationConfig, ModelParams, Part, SafetySetting, StreamGenerateContentResult, Tool, VertexInit } from './types/content'; | ||
export { VertexAI } from './vertex_ai'; | ||
export * from './types'; | ||
/** | ||
* Base class for authenticating to Vertex, creates the preview namespace. | ||
*/ | ||
export declare class VertexAI { | ||
preview: VertexAI_Preview; | ||
/** | ||
* @constructor | ||
* @param {VertexInit} init - assign authentication related information, | ||
* including project and location string, to instantiate a Vertex AI | ||
* client. | ||
*/ | ||
constructor(init: VertexInit); | ||
} | ||
/** | ||
* VertexAI class internal implementation for authentication. | ||
*/ | ||
export declare class VertexAI_Preview { | ||
readonly project: string; | ||
readonly location: string; | ||
readonly apiEndpoint?: string | undefined; | ||
readonly googleAuthOptions?: GoogleAuthOptions<import("google-auth-library/build/src/auth/googleauth").JSONClient> | undefined; | ||
protected googleAuth: GoogleAuth; | ||
/** | ||
* @constructor | ||
* @param {string} - project The Google Cloud project to use for the request | ||
* @param {string} - location The Google Cloud project location to use for the request | ||
* @param {string} - [apiEndpoint] The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @param {GoogleAuthOptions} - [googleAuthOptions] The Authentication options provided by google-auth-library. | ||
* Complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
constructor(project: string, location: string, apiEndpoint?: string | undefined, googleAuthOptions?: GoogleAuthOptions<import("google-auth-library/build/src/auth/googleauth").JSONClient> | undefined); | ||
/** | ||
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model. | ||
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel} | ||
*/ | ||
getGenerativeModel(modelParams: ModelParams): GenerativeModel; | ||
validateGoogleAuthOptions(project: string, googleAuthOptions?: GoogleAuthOptions): GoogleAuthOptions; | ||
} | ||
/** | ||
* Params to initiate a multiturn chat with the model via startChat | ||
* @property {Content[]} - [history] history of the chat session. {@link Content} | ||
* @property {SafetySetting[]} - [safety_settings] Array of {@link SafetySetting} | ||
* @property {GenerationConfig} - [generation_config] {@link GenerationConfig} | ||
*/ | ||
export declare interface StartChatParams { | ||
history?: Content[]; | ||
safety_settings?: SafetySetting[]; | ||
generation_config?: GenerationConfig; | ||
tools?: Tool[]; | ||
} | ||
/** | ||
* All params passed to initiate multiturn chat via startChat | ||
* @property {VertexAI_Preview} - _vertex_instance {@link VertexAI_Preview} | ||
* @property {GenerativeModel} - _model_instance {@link GenerativeModel} | ||
*/ | ||
export declare interface StartChatSessionRequest extends StartChatParams { | ||
project: string; | ||
location: string; | ||
_model_instance: GenerativeModel; | ||
} | ||
/** | ||
* @property {string} model - model name | ||
* @property {string} project - project The Google Cloud project to use for the request | ||
* @property {string} location - The Google Cloud project location to use for the request | ||
* @property {GoogleAuth} googleAuth - GoogleAuth class instance that handles authentication. | ||
* Details about GoogleAuth is referred to https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
* @property {string} - [apiEndpoint] The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @property {GenerationConfig} [generation_config] - {@link | ||
* GenerationConfig} | ||
* @property {SafetySetting[]} [safety_settings] - {@link SafetySetting} | ||
* @property {Tool[]} [tools] - {@link Tool} | ||
*/ | ||
export declare interface GetGenerativeModelParams extends ModelParams { | ||
model: string; | ||
project: string; | ||
location: string; | ||
googleAuth: GoogleAuth; | ||
apiEndpoint?: string; | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
} | ||
/** | ||
* Chat session to make multi-turn send message requests. | ||
* The `sendMessage` method makes an async call to get the response to a chat message. | ||
* The `sendMessageStream` method makes an async call to stream the response to a chat message. | ||
*/ | ||
export declare class ChatSession { | ||
private project; | ||
private location; | ||
private historyInternal; | ||
private _model_instance; | ||
private _send_stream_promise; | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
get history(): Content[]; | ||
/** | ||
* @constructor | ||
* @param {StartChatSessionRequest} request - {@link StartChatSessionRequest} | ||
*/ | ||
constructor(request: StartChatSessionRequest); | ||
/** | ||
* Make an async call to send a message and get the complete response. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<GenerateContentResult>} Promise of {@link GenerateContentResult} | ||
*/ | ||
sendMessage(request: string | Array<string | Part>): Promise<GenerateContentResult>; | ||
appendHistory(streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>, newContent: Content[]): Promise<void>; | ||
/** | ||
* Make an async call to send a message and stream the response. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
sendMessageStream(request: string | Array<string | Part>): Promise<StreamGenerateContentResult>; | ||
} | ||
/** | ||
* Base class for generative models. | ||
* NOTE: this class should not be instantiated directly. Use | ||
* `vertexai.preview.getGenerativeModel()` instead. | ||
*/ | ||
export declare class GenerativeModel { | ||
model: string; | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
private project; | ||
private location; | ||
private googleAuth; | ||
private publisherModelEndpoint; | ||
private apiEndpoint?; | ||
/** | ||
* @constructor | ||
* @param {GetGenerativeModelParams} getGenerativeModelParams - {@link GetGenerativeModelParams} | ||
*/ | ||
constructor(getGenerativeModelParams: GetGenerativeModelParams); | ||
/** | ||
* Get access token from GoogleAuth. Throws GoogleAuthError when it fails. | ||
* @return {Promise<any>} Promise of token | ||
*/ | ||
get token(): Promise<any>; | ||
/** | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResponse object with the response candidates. | ||
*/ | ||
generateContent(request: GenerateContentRequest | string): Promise<GenerateContentResult>; | ||
/** | ||
* Make an async stream request to generate content. The response is returned as a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
generateContentStream(request: GenerateContentRequest | string): Promise<StreamGenerateContentResult>; | ||
/** | ||
* Make an async request to count tokens. | ||
* @param request A CountTokensRequest object with the request contents. | ||
* @return The CountTokensResponse object with the token count. | ||
*/ | ||
countTokens(request: CountTokensRequest): Promise<CountTokensResponse>; | ||
/** | ||
* Instantiate a ChatSession. | ||
* This method doesn't make any call to a remote endpoint. | ||
* Any call to a remote endpoint is implemented in the ChatSession class @see ChatSession | ||
* @param{StartChatParams} [request] - {@link StartChatParams} | ||
* @return {ChatSession} {@link ChatSession} | ||
*/ | ||
startChat(request?: StartChatParams): ChatSession; | ||
} |
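For orientation, the public flow across these classes: VertexAI wires up auth, getGenerativeModel binds a model name and configs, and ChatSession accumulates history across turns. A usage sketch with placeholder project, location, and model values:

import {VertexAI} from '@google-cloud/vertexai';

async function demoChat(): Promise<void> {
  const vertexAI = new VertexAI({project: 'my-project', location: 'us-central1'}); // placeholders
  const model = vertexAI.preview.getGenerativeModel({model: 'gemini-pro'});        // placeholder model id
  const chat = model.startChat({});
  const result = await chat.sendMessage('How are you doing today?');
  console.log(result.response.candidates[0].content.parts[0].text);
  console.log(chat.history.length); // user turn + model turn appended on success
}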
@@ -33,470 +33,6 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.GenerativeModel = exports.ChatSession = exports.VertexAI_Preview = exports.VertexAI = void 0; | ||
/* tslint:disable */ | ||
const google_auth_library_1 = require("google-auth-library"); | ||
const process_stream_1 = require("./process_stream"); | ||
const errors_1 = require("./types/errors"); | ||
const util_1 = require("./util"); | ||
exports.VertexAI = void 0; | ||
var vertex_ai_1 = require("./vertex_ai"); | ||
Object.defineProperty(exports, "VertexAI", { enumerable: true, get: function () { return vertex_ai_1.VertexAI; } }); | ||
__exportStar(require("./types"), exports); | ||
/** | ||
* Base class for authenticating to Vertex, creates the preview namespace. | ||
*/ | ||
class VertexAI { | ||
/** | ||
* @constructor | ||
* @param {VertexInit} init - assign authentication related information, | ||
* including project and location string, to instantiate a Vertex AI | ||
* client. | ||
*/ | ||
constructor(init) { | ||
/** | ||
* preview property is used to access any SDK methods available in public | ||
* preview, currently all functionality. | ||
*/ | ||
this.preview = new VertexAI_Preview(init.project, init.location, init.apiEndpoint, init.googleAuthOptions); | ||
} | ||
} | ||
exports.VertexAI = VertexAI; | ||
/** | ||
* VertexAI class internal implementation for authentication. | ||
*/ | ||
class VertexAI_Preview { | ||
/** | ||
* @constructor | ||
* @param {string} - project The Google Cloud project to use for the request | ||
* @param {string} - location The Google Cloud project location to use for the request | ||
* @param {string} - [apiEndpoint] The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @param {GoogleAuthOptions} - [googleAuthOptions] The Authentication options provided by google-auth-library. | ||
* Complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
constructor(project, location, apiEndpoint, googleAuthOptions) { | ||
this.project = project; | ||
this.location = location; | ||
this.apiEndpoint = apiEndpoint; | ||
this.googleAuthOptions = googleAuthOptions; | ||
const opts = this.validateGoogleAuthOptions(project, googleAuthOptions); | ||
this.project = project; | ||
this.location = location; | ||
this.apiEndpoint = apiEndpoint; | ||
this.googleAuth = new google_auth_library_1.GoogleAuth(opts); | ||
} | ||
/** | ||
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model. | ||
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel} | ||
*/ | ||
getGenerativeModel(modelParams) { | ||
const getGenerativeModelParams = { | ||
model: modelParams.model, | ||
project: this.project, | ||
location: this.location, | ||
googleAuth: this.googleAuth, | ||
apiEndpoint: this.apiEndpoint, | ||
safety_settings: modelParams.safety_settings, | ||
tools: modelParams.tools, | ||
}; | ||
if (modelParams.generation_config) { | ||
getGenerativeModelParams.generation_config = validateGenerationConfig(modelParams.generation_config); | ||
} | ||
return new GenerativeModel(getGenerativeModelParams); | ||
} | ||
validateGoogleAuthOptions(project, googleAuthOptions) { | ||
let opts; | ||
const requiredScope = 'https://www.googleapis.com/auth/cloud-platform'; | ||
if (!googleAuthOptions) { | ||
opts = { | ||
scopes: requiredScope, | ||
}; | ||
return opts; | ||
} | ||
if (googleAuthOptions.projectId && | ||
googleAuthOptions.projectId !== project) { | ||
throw new Error(`inconsistent project ID values. argument project got value ${project} but googleAuthOptions.projectId got value ${googleAuthOptions.projectId}`); | ||
} | ||
opts = googleAuthOptions; | ||
if (!opts.scopes) { | ||
opts.scopes = requiredScope; | ||
return opts; | ||
} | ||
if ((typeof opts.scopes === 'string' && opts.scopes !== requiredScope) || | ||
(Array.isArray(opts.scopes) && opts.scopes.indexOf(requiredScope) < 0)) { | ||
throw new errors_1.GoogleAuthError(`input GoogleAuthOptions.scopes ${opts.scopes} doesn't contain required scope ${requiredScope}, please include ${requiredScope} into GoogleAuthOptions.scopes or leave GoogleAuthOptions.scopes undefined`); | ||
} | ||
return opts; | ||
} | ||
} | ||
exports.VertexAI_Preview = VertexAI_Preview; | ||
/** | ||
* Chat session to make multi-turn send message requests. | ||
* The `sendMessage` method makes an async call to get the response to a chat message. | ||
* The `sendMessageStream` method makes an async call to stream the response to a chat message. | ||
*/ | ||
class ChatSession { | ||
get history() { | ||
return this.historyInternal; | ||
} | ||
/** | ||
* @constructor | ||
* @param {StartChatSessionRequest} request - {@link StartChatSessionRequest} | ||
*/ | ||
constructor(request) { | ||
var _a; | ||
this._send_stream_promise = Promise.resolve(); | ||
this.project = request.project; | ||
this.location = request.location; | ||
this._model_instance = request._model_instance; | ||
this.historyInternal = (_a = request.history) !== null && _a !== void 0 ? _a : []; | ||
this.generation_config = request.generation_config; | ||
this.safety_settings = request.safety_settings; | ||
this.tools = request.tools; | ||
} | ||
/** | ||
* Make an async call to send a chat message. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<GenerateContentResult>} Promise of {@link GenerateContentResult} | ||
*/ | ||
async sendMessage(request) { | ||
const newContent = formulateNewContentFromSendMessageRequest(request); | ||
const generateContentRequest = { | ||
contents: this.historyInternal.concat(newContent), | ||
safety_settings: this.safety_settings, | ||
generation_config: this.generation_config, | ||
tools: this.tools, | ||
}; | ||
const generateContentResult = await this._model_instance | ||
.generateContent(generateContentRequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
const generateContentResponse = await generateContentResult.response; | ||
// Only push the latest message to history if the response returned a result | ||
if (generateContentResponse.candidates.length !== 0) { | ||
this.historyInternal = this.historyInternal.concat(newContent); | ||
const contentFromAssistant = generateContentResponse.candidates[0].content; | ||
if (!contentFromAssistant.role) { | ||
contentFromAssistant.role = util_1.constants.MODEL_ROLE; | ||
} | ||
this.historyInternal.push(contentFromAssistant); | ||
} | ||
else { | ||
// TODO: handle promptFeedback in the response | ||
throw new Error('Did not get a candidate from the model'); | ||
} | ||
return Promise.resolve({ response: generateContentResponse }); | ||
} | ||
async appendHistory(streamGenerateContentResultPromise, newContent) { | ||
const streamGenerateContentResult = await streamGenerateContentResultPromise; | ||
const streamGenerateContentResponse = await streamGenerateContentResult.response; | ||
// Only push the latest message to history if the response returned a result | ||
if (streamGenerateContentResponse.candidates.length !== 0) { | ||
this.historyInternal = this.historyInternal.concat(newContent); | ||
const contentFromAssistant = streamGenerateContentResponse.candidates[0].content; | ||
if (!contentFromAssistant.role) { | ||
contentFromAssistant.role = util_1.constants.MODEL_ROLE; | ||
} | ||
this.historyInternal.push(contentFromAssistant); | ||
} | ||
else { | ||
// TODO: handle promptFeedback in the response | ||
throw new Error('Did not get a candidate from the model'); | ||
} | ||
} | ||
/** | ||
* Make an async call to send a chat message and stream the response. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
async sendMessageStream(request) { | ||
const newContent = formulateNewContentFromSendMessageRequest(request); | ||
const generateContentRequest = { | ||
contents: this.historyInternal.concat(newContent), | ||
safety_settings: this.safety_settings, | ||
generation_config: this.generation_config, | ||
tools: this.tools, | ||
}; | ||
const streamGenerateContentResultPromise = this._model_instance | ||
.generateContentStream(generateContentRequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
this._send_stream_promise = this.appendHistory(streamGenerateContentResultPromise, newContent).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception appending chat history', e); | ||
}); | ||
return streamGenerateContentResultPromise; | ||
} | ||
} | ||
exports.ChatSession = ChatSession; | ||
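/* | ||
 * Usage sketch for the chat flow above, assuming `model` is a GenerativeModel | ||
 * obtained from getGenerativeModel; the prompt is a placeholder. | ||
 * | ||
 *   const chat = model.startChat(); | ||
 *   const result = await chat.sendMessage('How can I learn more about Node.js?'); | ||
 *   console.log(result.response.candidates[0].content); | ||
 *   // chat.history now contains the user turn followed by the model turn. | ||
 */ | ||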
/** | ||
* Base class for generative models. | ||
* NOTE: this class should not be instantiated directly. Use | ||
* `vertexai.preview.getGenerativeModel()` instead. | ||
*/ | ||
class GenerativeModel { | ||
/** | ||
* @constructor | ||
* @param {GetGenerativeModelParams} getGenerativeModelParams - {@link GetGenerativeModelParams} | ||
*/ | ||
constructor(getGenerativeModelParams) { | ||
this.project = getGenerativeModelParams.project; | ||
this.location = getGenerativeModelParams.location; | ||
this.apiEndpoint = getGenerativeModelParams.apiEndpoint; | ||
this.googleAuth = getGenerativeModelParams.googleAuth; | ||
this.model = getGenerativeModelParams.model; | ||
this.generation_config = getGenerativeModelParams.generation_config; | ||
this.safety_settings = getGenerativeModelParams.safety_settings; | ||
this.tools = getGenerativeModelParams.tools; | ||
if (this.model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} | ||
else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
} | ||
/** | ||
* Get access token from GoogleAuth. Throws GoogleAuthError when it fails. | ||
* @return {Promise<any>} Promise of token | ||
*/ | ||
get token() { | ||
const credential_error_message = '\nUnable to authenticate your request\ | ||
\nDepending on your run time environment, you can get authentication by\ | ||
\n- if in local instance or cloud shell: `!gcloud auth login`\ | ||
\n- if in Colab:\ | ||
\n -`from google.colab import auth`\ | ||
\n -`auth.authenticate_user()`\ | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication'; | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new errors_1.GoogleAuthError(credential_error_message, e); | ||
}); | ||
return tokenPromise; | ||
} | ||
/** | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResult object with the response candidates. | ||
*/ | ||
async generateContent(request) { | ||
var _a, _b, _c; | ||
request = formatContentRequest(request, this.generation_config, this.safety_settings); | ||
validateGenerateContentRequest(request); | ||
if (request.generation_config) { | ||
request.generation_config = validateGenerationConfig(request.generation_config); | ||
} | ||
const generateContentRequest = { | ||
contents: request.contents, | ||
generation_config: (_a = request.generation_config) !== null && _a !== void 0 ? _a : this.generation_config, | ||
safety_settings: (_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings, | ||
tools: (_c = request.tools) !== null && _c !== void 0 ? _c : [], | ||
}; | ||
const response = await (0, util_1.postRequest)({ | ||
region: this.location, | ||
project: this.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: util_1.constants.GENERATE_CONTENT_METHOD, | ||
token: await this.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this.apiEndpoint, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const result = (0, process_stream_1.processNonStream)(response); | ||
return Promise.resolve(result); | ||
} | ||
/** | ||
* Make an async stream request to generate content. The response will be returned as a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
async generateContentStream(request) { | ||
var _a, _b, _c; | ||
request = formatContentRequest(request, this.generation_config, this.safety_settings); | ||
validateGenerateContentRequest(request); | ||
if (request.generation_config) { | ||
request.generation_config = validateGenerationConfig(request.generation_config); | ||
} | ||
const generateContentRequest = { | ||
contents: request.contents, | ||
generation_config: (_a = request.generation_config) !== null && _a !== void 0 ? _a : this.generation_config, | ||
safety_settings: (_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings, | ||
tools: (_c = request.tools) !== null && _c !== void 0 ? _c : [], | ||
}; | ||
const response = await (0, util_1.postRequest)({ | ||
region: this.location, | ||
project: this.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: util_1.constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await this.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this.apiEndpoint, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const streamResult = (0, process_stream_1.processStream)(response); | ||
return Promise.resolve(streamResult); | ||
} | ||
/** | ||
* Make an async request to count tokens. | ||
* @param request A CountTokensRequest object with the request contents. | ||
* @return The CountTokensResponse object with the token count. | ||
*/ | ||
async countTokens(request) { | ||
const response = await (0, util_1.postRequest)({ | ||
region: this.location, | ||
project: this.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: 'countTokens', | ||
token: await this.token, | ||
data: request, | ||
apiEndpoint: this.apiEndpoint, | ||
}).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
return (0, process_stream_1.processCountTokenResponse)(response); | ||
} | ||
/** | ||
* Instantiate a ChatSession. | ||
* This method does not make any calls to the remote endpoint. | ||
* Any calls to the remote endpoint are implemented in the ChatSession class. @see ChatSession | ||
* @param {StartChatParams} [request] - {@link StartChatParams} | ||
* @return {ChatSession} {@link ChatSession} | ||
*/ | ||
startChat(request) { | ||
var _a, _b, _c; | ||
const startChatRequest = { | ||
project: this.project, | ||
location: this.location, | ||
_model_instance: this, | ||
}; | ||
if (request) { | ||
startChatRequest.history = request.history; | ||
startChatRequest.generation_config = | ||
(_a = request.generation_config) !== null && _a !== void 0 ? _a : this.generation_config; | ||
startChatRequest.safety_settings = | ||
(_b = request.safety_settings) !== null && _b !== void 0 ? _b : this.safety_settings; | ||
startChatRequest.tools = (_c = request.tools) !== null && _c !== void 0 ? _c : this.tools; | ||
} | ||
return new ChatSession(startChatRequest); | ||
} | ||
} | ||
exports.GenerativeModel = GenerativeModel; | ||
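/* | ||
 * Usage sketch for the streaming path above, assuming `model` is a | ||
 * GenerativeModel instance; the prompt is a placeholder. | ||
 * | ||
 *   const streamingResp = await model.generateContentStream('Tell me a story'); | ||
 *   for await (const item of streamingResp.stream) { | ||
 *     console.log(item.candidates[0].content.parts[0].text); // chunks as they arrive | ||
 *   } | ||
 *   const aggregated = await streamingResp.response; // aggregated response | ||
 */ | ||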
function formulateNewContentFromSendMessageRequest(request) { | ||
let newParts = []; | ||
if (typeof request === 'string') { | ||
newParts = [{ text: request }]; | ||
} | ||
else if (Array.isArray(request)) { | ||
for (const item of request) { | ||
if (typeof item === 'string') { | ||
newParts.push({ text: item }); | ||
} | ||
else { | ||
newParts.push(item); | ||
} | ||
} | ||
} | ||
return assignRoleToPartsAndValidateSendMessageRequest(newParts); | ||
} | ||
/** | ||
* When multiple Part types (e.g. FunctionResponsePart and TextPart) are | ||
* passed in a single Part array, we may need to assign different roles to each | ||
* part. Currently only FunctionResponsePart requires a role other than 'user'. | ||
* @ignore | ||
* @param {Array<Part>} parts Array of parts to pass to the model | ||
* @return {Content[]} Array of content items | ||
*/ | ||
function assignRoleToPartsAndValidateSendMessageRequest(parts) { | ||
const userContent = { role: util_1.constants.USER_ROLE, parts: [] }; | ||
const functionContent = { role: util_1.constants.FUNCTION_ROLE, parts: [] }; | ||
let hasUserContent = false; | ||
let hasFunctionContent = false; | ||
for (const part of parts) { | ||
if ('functionResponse' in part) { | ||
functionContent.parts.push(part); | ||
hasFunctionContent = true; | ||
} | ||
else { | ||
userContent.parts.push(part); | ||
hasUserContent = true; | ||
} | ||
} | ||
if (hasUserContent && hasFunctionContent) { | ||
throw new errors_1.ClientError('Within a single message, FunctionResponse cannot be mixed with other types of parts in the request for sending a chat message.'); | ||
} | ||
if (!hasUserContent && !hasFunctionContent) { | ||
throw new errors_1.ClientError('No content is provided for sending a chat message.'); | ||
} | ||
if (hasUserContent) { | ||
return [userContent]; | ||
} | ||
return [functionContent]; | ||
} | ||
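/* | ||
 * Illustration of the normalization above: a plain string such as 'hello' | ||
 * becomes [{role: 'user', parts: [{text: 'hello'}]}], while an array whose | ||
 * parts all carry functionResponse becomes a single turn with role | ||
 * 'function'; mixing the two kinds of parts in one message throws ClientError. | ||
 */ | ||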
function throwErrorIfNotOK(response) { | ||
if (response === undefined) { | ||
throw new errors_1.GoogleGenerativeAIError('response is undefined'); | ||
} | ||
const status = response.status; | ||
const statusText = response.statusText; | ||
const errorMessage = `got status: ${status} ${statusText}`; | ||
if (status >= 400 && status < 500) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
if (!response.ok) { | ||
throw new errors_1.GoogleGenerativeAIError(errorMessage); | ||
} | ||
} | ||
function validateGcsInput(contents) { | ||
for (const content of contents) { | ||
for (const part of content.parts) { | ||
if ('file_data' in part) { | ||
// @ts-ignore | ||
const uri = part['file_data']['file_uri']; | ||
if (!uri.startsWith('gs://')) { | ||
throw new URIError(`Found invalid Google Cloud Storage URI ${uri}; Google Cloud Storage URIs must start with gs://`); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
function validateFunctionResponseRequest(contents) { | ||
const latestContentPart = contents[contents.length - 1].parts[0]; | ||
if (!('functionResponse' in latestContentPart)) { | ||
return; | ||
} | ||
const errorMessage = 'Please ensure that function response turn comes immediately after a function call turn.'; | ||
if (contents.length < 2) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
const secondLatestContentPart = contents[contents.length - 2].parts[0]; | ||
if (!('functionCall' in secondLatestContentPart)) { | ||
throw new errors_1.ClientError(errorMessage); | ||
} | ||
} | ||
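/* | ||
 * A history shape that passes the ordering check above (sketch; the function | ||
 * name and arguments are placeholders): | ||
 * | ||
 *   const contents = [ | ||
 *     {role: 'user', parts: [{text: 'What is the weather in Boston?'}]}, | ||
 *     {role: 'model', parts: [{functionCall: {name: 'get_current_weather', args: {location: 'Boston'}}}]}, | ||
 *     {role: 'function', parts: [{functionResponse: {name: 'get_current_weather', response: {weather: 'sunny'}}}]}, | ||
 *   ]; | ||
 */ | ||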
function validateGenerateContentRequest(request) { | ||
validateGcsInput(request.contents); | ||
validateFunctionResponseRequest(request.contents); | ||
} | ||
function validateGenerationConfig(generation_config) { | ||
if ('top_k' in generation_config) { | ||
if (!(generation_config.top_k > 0) || !(generation_config.top_k <= 40)) { | ||
delete generation_config.top_k; | ||
} | ||
} | ||
return generation_config; | ||
} | ||
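/* | ||
 * Example of the validation above: {top_k: 0, top_p: 0.9} is sent as | ||
 * {top_p: 0.9} because a top_k outside (0, 40] is dropped, while {top_k: 1} | ||
 * is forwarded unchanged. | ||
 */ | ||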
function formatContentRequest(request, generation_config, safety_settings) { | ||
if (typeof request === 'string') { | ||
return { | ||
contents: [{ role: util_1.constants.USER_ROLE, parts: [{ text: request }] }], | ||
generation_config: generation_config, | ||
safety_settings: safety_settings, | ||
}; | ||
} | ||
else { | ||
return request; | ||
} | ||
} | ||
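/* | ||
 * Consequence of the shorthand above (sketch; the prompt is a placeholder): | ||
 * | ||
 *   model.generateContent('Why is the sky blue?') | ||
 * | ||
 * is equivalent to | ||
 * | ||
 *   model.generateContent({contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}]}) | ||
 * | ||
 * with the model-level generation_config and safety_settings applied. | ||
 */ | ||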
//# sourceMappingURL=index.js.map |
@@ -17,3 +17,3 @@ /** | ||
*/ | ||
import { GoogleAuthOptions } from 'google-auth-library'; | ||
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; | ||
/** | ||
@@ -59,2 +59,26 @@ * Params used to initialize the Vertex SDK | ||
/** | ||
* @property {string} model - model name | ||
* @property {string} project - The Google Cloud project to use for the request | ||
* @property {string} location - The Google Cloud project location to use for the request | ||
* @property {GoogleAuth} googleAuth - GoogleAuth class instance that handles authentication. | ||
* Details about GoogleAuth are documented in https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
* @property {string} [apiEndpoint] - The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (e.g. us-central1-aiplatform.googleapis.com) will be used. | ||
* @property {GenerationConfig} [generation_config] - {@link | ||
* GenerationConfig} | ||
* @property {SafetySetting[]} [safety_settings] - {@link SafetySetting} | ||
* @property {Tool[]} [tools] - {@link Tool} | ||
*/ | ||
export declare interface GetGenerativeModelParams extends ModelParams { | ||
model: string; | ||
project: string; | ||
location: string; | ||
googleAuth: GoogleAuth; | ||
apiEndpoint?: string; | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
} | ||
/** | ||
* Configuration for initializing a model, for example via getGenerativeModel | ||
@@ -430,3 +454,4 @@ * @property {string} model - model name. | ||
* @property {Content} - content. {@link Content} | ||
* @property {number} - [index]. The index of the candidate in the {@link GenerateContentResponse} | ||
* @property {number} - [index]. The index of the candidate in the {@link | ||
* GenerateContentResponse} | ||
* @property {FinishReason} - [finishReason]. {@link FinishReason} | ||
@@ -436,2 +461,4 @@ * @property {string} - [finishMessage]. | ||
* @property {CitationMetadata} - [citationMetadata]. {@link CitationMetadata} | ||
* @property {GroundingMetadata} - [groundingMetadata]. {@link | ||
* GroundingMetadata} | ||
*/ | ||
@@ -445,2 +472,3 @@ export declare interface GenerateContentCandidate { | ||
citationMetadata?: CitationMetadata; | ||
groundingMetadata?: GroundingMetadata; | ||
functionCall?: FunctionCall; | ||
@@ -469,2 +497,49 @@ } | ||
/** | ||
* A collection of grounding attributions for a piece of content. | ||
* @property {string[]} - [webSearchQueries]. Web search queries for the | ||
* following-up web search. | ||
* @property {GroundingAttribution[]} - [groundingAttributions]. Array of {@link | ||
* GroundingAttribution} | ||
*/ | ||
export declare interface GroundingMetadata { | ||
webSearchQueries?: string[]; | ||
groundingAttributions?: GroundingAttribution[]; | ||
} | ||
/** | ||
* Grounding attribution. | ||
* @property {GroundingAttributionWeb} - [web] Attribution from the web. | ||
* @property {GroundingAttributionSegment} - [segment] Segment of the content | ||
* this attribution belongs to. | ||
* @property {number} - [confidenceScore] Confidence score of the attribution. | ||
* Ranges from 0 to 1. 1 is the most confident. | ||
*/ | ||
export declare interface GroundingAttribution { | ||
web?: GroundingAttributionWeb; | ||
segment?: GroundingAttributionSegment; | ||
confidenceScore?: number; | ||
} | ||
/** | ||
* Segment of the content this attribution belongs to. | ||
* @property {number} - [partIndex] The index of a Part object within its | ||
* parent Content object. | ||
* @property {number} - [startIndex] Start index in the given Part, measured in | ||
* bytes. Offset from the start of the Part, inclusive, starting at zero. | ||
* @property {number} - [endIndex] End index in the given Part, measured in | ||
* bytes. Offset from the start of the Part, exclusive, starting at zero. | ||
*/ | ||
export declare interface GroundingAttributionSegment { | ||
partIndex?: number; | ||
startIndex?: number; | ||
endIndex?: number; | ||
} | ||
/** | ||
* Attribution from the web. | ||
* @property {string} - [uri] URI reference of the attribution. | ||
* @property {string} - [title] Title of the attribution. | ||
*/ | ||
export declare interface GroundingAttributionWeb { | ||
uri?: string; | ||
title?: string; | ||
} | ||
/** | ||
* A predicted FunctionCall returned from the model that contains a string | ||
@@ -527,5 +602,5 @@ * representating the FunctionDeclaration.name with the parameters and their | ||
/** | ||
* A Tool is a piece of code that enables the system to interact with | ||
* external systems to perform an action, or set of actions, outside of | ||
* knowledge and scope of the model. | ||
* A FunctionDeclarationsTool is a piece of code that enables the system to | ||
* interact with external systems to perform an action, or set of actions, | ||
* outside of knowledge and scope of the model. | ||
* @property {object} - function_declarations One or more function declarations | ||
@@ -540,6 +615,43 @@ * to be passed to the model along with the current user query. Model may decide | ||
*/ | ||
export declare interface Tool { | ||
function_declarations: FunctionDeclaration[]; | ||
export declare interface FunctionDeclarationsTool { | ||
function_declarations?: FunctionDeclaration[]; | ||
} | ||
export declare interface RetrievalTool { | ||
retrieval?: Retrieval; | ||
} | ||
export declare interface GoogleSearchRetrievalTool { | ||
googleSearchRetrieval?: GoogleSearchRetrieval; | ||
} | ||
export declare type Tool = FunctionDeclarationsTool | RetrievalTool | GoogleSearchRetrievalTool; | ||
/** | ||
* Defines a retrieval tool that the model can call to access external knowledge. | ||
* @property {VertexAISearch} - [vertexAiSearch] Set to use a data source powered | ||
* by Vertex AI Search. | ||
* @property {boolean} - [disableAttribution] Disable using the result from | ||
* this tool in detecting grounding attribution. This does not affect how the | ||
* result is given to the model for generation. | ||
*/ | ||
export declare interface Retrieval { | ||
vertexAiSearch?: VertexAISearch; | ||
disableAttribution?: boolean; | ||
} | ||
/** | ||
* Tool to retrieve public web data for grounding, powered by Google. | ||
* @property {boolean} - [disableAttribution] Disable using the result from this | ||
* tool in detecting grounding attribution. This does not affect how the result | ||
* is given to the model for generation. | ||
*/ | ||
export declare interface GoogleSearchRetrieval { | ||
disableAttribution?: boolean; | ||
} | ||
/** | ||
* Retrieve from a Vertex AI Search datastore for grounding. See | ||
* https://cloud.google.com/vertex-ai-search-and-conversation | ||
* @property {string} - datastore Fully-qualified Vertex AI Search datastore | ||
* resource ID: projects/<>/locations/<>/collections/<>/dataStores/<> | ||
*/ | ||
export declare interface VertexAISearch { | ||
datastore: string; | ||
} | ||
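/** | ||
 * Illustrative Tool literals for the union above (sketch; the datastore path | ||
 * is a placeholder): | ||
 * | ||
 *   const searchTool: GoogleSearchRetrievalTool = { | ||
 *     googleSearchRetrieval: {disableAttribution: false}, | ||
 *   }; | ||
 *   const retrievalTool: RetrievalTool = { | ||
 *     retrieval: { | ||
 *       vertexAiSearch: {datastore: 'projects/p/locations/l/collections/c/dataStores/d'}, | ||
 *       disableAttribution: false, | ||
 *     }, | ||
 *   }; | ||
 */ | ||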
/** | ||
* Contains the list of OpenAPI data types | ||
@@ -588,1 +700,25 @@ * as defined by https://swagger.io/docs/specification/data-models/data-types/ | ||
} | ||
/** | ||
* Params to initiate a multiturn chat with the model via startChat | ||
* @property {Content[]} - [history] history of the chat session. {@link Content} | ||
* @property {SafetySetting[]} - [safety_settings] Array of {@link SafetySetting} | ||
* @property {GenerationConfig} - [generation_config] {@link GenerationConfig} | ||
* @property {Tool[]} - [tools] Array of {@link Tool} | ||
*/ | ||
export declare interface StartChatParams { | ||
history?: Content[]; | ||
safety_settings?: SafetySetting[]; | ||
generation_config?: GenerationConfig; | ||
tools?: Tool[]; | ||
api_endpoint?: string; | ||
} | ||
/** | ||
* All params passed to initiate multiturn chat via startChat | ||
* @property {string} project - project The Google Cloud project to use for the request | ||
* @property {string} location - The Google Cloud project location to use for the request | ||
*/ | ||
export declare interface StartChatSessionRequest extends StartChatParams { | ||
project: string; | ||
location: string; | ||
googleAuth: GoogleAuth; | ||
publisher_model_endpoint: string; | ||
} |
@@ -22,2 +22,3 @@ /** | ||
export declare const FUNCTION_ROLE = "function"; | ||
export declare const USER_AGENT = "model-builder/0.3.1 grpc-node/0.3.1"; | ||
export declare const USER_AGENT = "model-builder/0.4.0 grpc-node/0.4.0"; | ||
export declare const CREDENTIAL_ERROR_MESSAGE = "\nUnable to authenticate your request \nDepending on your run time environment, you can get authentication by \n- if in local instance or cloud shell: `!gcloud auth login` \n- if in Colab: \n -`from google.colab import auth` \n -`auth.authenticate_user()` \n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication"; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.USER_AGENT = exports.FUNCTION_ROLE = exports.MODEL_ROLE = exports.USER_ROLE = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0; | ||
exports.CREDENTIAL_ERROR_MESSAGE = exports.USER_AGENT = exports.FUNCTION_ROLE = exports.MODEL_ROLE = exports.USER_ROLE = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0; | ||
/** | ||
@@ -26,5 +26,12 @@ * @license | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_VERSION = '0.3.1'; // x-release-please-version | ||
const CLIENT_LIBRARY_VERSION = '0.4.0'; // x-release-please-version | ||
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`; | ||
exports.USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`; | ||
exports.CREDENTIAL_ERROR_MESSAGE = '\nUnable to authenticate your request\ | ||
\nDepending on your run time environment, you can get authentication by\ | ||
\n- if in local instance or cloud shell: `!gcloud auth login`\ | ||
\n- if in Colab:\ | ||
\n -`from google.colab import auth`\ | ||
\n -`auth.authenticate_user()`\ | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication'; | ||
//# sourceMappingURL=constants.js.map |
@@ -18,2 +18,1 @@ /** | ||
export * as constants from './constants'; | ||
export { postRequest } from './post_request'; |
@@ -19,6 +19,4 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.postRequest = exports.constants = void 0; | ||
exports.constants = void 0; | ||
exports.constants = require("./constants"); | ||
var post_request_1 = require("./post_request"); | ||
Object.defineProperty(exports, "postRequest", { enumerable: true, get: function () { return post_request_1.postRequest; } }); | ||
//# sourceMappingURL=index.js.map |
@@ -71,2 +71,9 @@ "use strict"; | ||
]; | ||
const TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL = [ | ||
{ | ||
googleSearchRetrieval: { | ||
disableAttribution: false, | ||
}, | ||
}, | ||
]; | ||
const WEATHER_FORECAST = 'super nice'; | ||
@@ -92,3 +99,3 @@ const FUNCTION_RESPONSE_PART = [ | ||
}); | ||
const generativeTextModel = vertex_ai.preview.getGenerativeModel({ | ||
const generativeTextModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
@@ -99,3 +106,9 @@ generation_config: { | ||
}); | ||
const generativeTextModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
const generativeTextModelPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const generativeTextModelWithPrefix = vertex_ai.getGenerativeModel({ | ||
model: 'models/gemini-pro', | ||
@@ -106,11 +119,26 @@ generation_config: { | ||
}); | ||
const textModelNoOutputLimit = vertex_ai.preview.getGenerativeModel({ | ||
const generativeTextModelWithPrefixPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const textModelNoOutputLimit = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeVisionModel = vertex_ai.preview.getGenerativeModel({ | ||
const textModelNoOutputLimitPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeVisionModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro-vision', | ||
}); | ||
const generativeVisionModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
const generativeVisionModelPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro-vision', | ||
}); | ||
const generativeVisionModelWithPrefix = vertex_ai.getGenerativeModel({ | ||
model: 'models/gemini-pro-vision', | ||
}); | ||
const generativeVisionModelWithPrefixPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro-vision', | ||
}); | ||
describe('generateContentStream', () => { | ||
@@ -128,2 +156,10 @@ beforeEach(() => { | ||
}); | ||
it('in preview should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = await generativeTextModelPreview.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should not return invalid unicode', async () => { | ||
@@ -144,2 +180,17 @@ const streamingResp = await generativeTextModel.generateContentStream({ | ||
}); | ||
it('in preview should not return invalid unicode', async () => { | ||
const streamingResp = await generativeTextModelPreview.generateContentStream({ | ||
contents: [{ role: 'user', parts: [{ text: '创作一首古诗' }] }], | ||
}); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview, for item ${item}`); | ||
for (const candidate of item.candidates) { | ||
for (const part of candidate.content.parts) { | ||
expect(part.text).not.toContain('\ufffd', `sys test failure on generateContentStream in preview, for item ${item}`); | ||
} | ||
} | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content', async () => { | ||
@@ -153,2 +204,10 @@ const streamingResp = await generativeVisionModel.generateContentStream(MULTI_PART_BASE64_REQUEST); | ||
}); | ||
it('in preview should return a stream and aggregated response when passed multipart base64 content', async () => { | ||
const streamingResp = await generativeVisionModelPreview.generateContentStream(MULTI_PART_BASE64_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should throw ClientError when having invalid input', async () => { | ||
@@ -168,6 +227,26 @@ const badRequest = { | ||
expect(e).toBeInstanceOf(src_1.ClientError); | ||
expect(e.message).toBe('[VertexAI.ClientError]: got status: 400 Bad Request', `sys test failure on generateContentStream when having bad request | ||
expect(e.message).toContain('[VertexAI.ClientError]: got status: 400 Bad Request', `sys test failure on generateContentStream when having bad request | ||
got wrong error message: ${e.message}`); | ||
}); | ||
}); | ||
it('in preview should throw ClientError when having invalid input', async () => { | ||
const badRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{ text: 'describe this image:' }, | ||
{ inline_data: { mime_type: 'image/png', data: 'invalid data' } }, | ||
], | ||
}, | ||
], | ||
}; | ||
await generativeVisionModelPreview | ||
.generateContentStream(badRequest) | ||
.catch(e => { | ||
expect(e).toBeInstanceOf(src_1.ClientError); | ||
expect(e.message).toContain('[VertexAI.ClientError]: got status: 400 Bad Request', `sys test failure on generateContentStream in preview when having bad request | ||
got wrong error message: ${e.message}`); | ||
}); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart GCS content', async () => { | ||
@@ -181,2 +260,10 @@ const streamingResp = await generativeVisionModel.generateContentStream(MULTI_PART_GCS_REQUEST); | ||
}); | ||
it('in preview should return a stream and aggregated response when passed multipart GCS content', async () => { | ||
const streamingResp = await generativeVisionModelPreview.generateContentStream(MULTI_PART_GCS_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
@@ -198,2 +285,18 @@ var _a; | ||
}); | ||
it('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
var _a; | ||
const request = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather in Boston?' }] }, | ||
{ role: 'model', parts: FUNCTION_CALL }, | ||
{ role: 'function', parts: FUNCTION_RESPONSE_PART }, | ||
], | ||
tools: TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const streamingResp = await generativeTextModelPreview.generateContentStream(request); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview, for item ${item}`); | ||
expect((_a = item.candidates[0].content.parts[0].text) === null || _a === void 0 ? void 0 : _a.toLowerCase()).toContain(WEATHER_FORECAST); | ||
} | ||
}); | ||
}); | ||
@@ -206,5 +309,27 @@ describe('generateContent', () => { | ||
const response = await generativeTextModel.generateContent(TEXT_REQUEST); | ||
const aggregatedResp = await response.response; | ||
const aggregatedResp = response.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('in preview should return the aggregated response', async () => { | ||
const response = await generativeTextModelPreview.generateContent(TEXT_REQUEST); | ||
const aggregatedResp = response.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}`); | ||
}); | ||
xit('should return grounding metadata when passed GoogleSearchRetrieval or Retrieval', async () => { | ||
const generativeTextModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
//tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, | ||
}); | ||
const result = await generativeTextModel.generateContent({ | ||
contents: [{ role: 'user', parts: [{ text: 'Why is the sky blue?' }] }], | ||
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, | ||
}); | ||
const response = result.response; | ||
const groundingMetadata = response.candidates[0].groundingMetadata; | ||
expect(groundingMetadata).toBeDefined(); | ||
if (groundingMetadata) { | ||
// expect(groundingMetadata.groundingAttributions).toBeTruthy(); | ||
expect(groundingMetadata.webSearchQueries).toBeTruthy(); | ||
} | ||
}); | ||
}); | ||
@@ -219,6 +344,14 @@ describe('sendMessage', () => { | ||
const result1 = await chat.sendMessage(chatInput1); | ||
const response1 = await result1.response; | ||
const response1 = result1.response; | ||
expect(response1.candidates[0]).toBeTruthy(`sys test failure on sendMessage for aggregated response: ${response1}`); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('in preview should populate history and return a chat response', async () => { | ||
const chat = generativeTextModelPreview.startChat(); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessage(chatInput1); | ||
const response1 = result1.response; | ||
expect(response1.candidates[0]).toBeTruthy(`sys test failure on sendMessage in preview for aggregated response: ${response1}`); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
}); | ||
@@ -244,2 +377,17 @@ describe('sendMessageStream', () => { | ||
}); | ||
it('in preview should return a stream and populate history when generation_config is passed to startChat', async () => { | ||
const chat = generativeTextModelPreview.startChat({ | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream in preview, for item ${item}`); | ||
} | ||
const resp = await result1.response; | ||
expect(resp.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream in preview for aggregated response: ${resp}`); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return a stream and populate history when startChat is passed no request obj', async () => { | ||
@@ -256,2 +404,13 @@ const chat = generativeTextModel.startChat(); | ||
}); | ||
it('in preview should return a stream and populate history when startChat is passed no request obj', async () => { | ||
const chat = generativeTextModelPreview.startChat(); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream in preview, for item ${item}`); | ||
} | ||
const resp = await result1.response; | ||
expect(resp.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream in preview for aggregated response: ${resp}`); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return chunks as they come in', async () => { | ||
@@ -273,2 +432,18 @@ const chat = textModelNoOutputLimit.startChat({}); | ||
}); | ||
it('in preview should return chunks as they come in', async () => { | ||
const chat = textModelNoOutputLimitPreview.startChat({}); | ||
const chatInput1 = 'Tell me a story in 3000 words'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
let firstChunkTimestamp = 0; | ||
let aggregatedResultTimestamp = 0; | ||
const firstChunkFinalResultTimeDiff = 200; // ms | ||
for await (const item of result1.stream) { | ||
if (firstChunkTimestamp === 0) { | ||
firstChunkTimestamp = Date.now(); | ||
} | ||
} | ||
await result1.response; | ||
aggregatedResultTimestamp = Date.now(); | ||
expect(aggregatedResultTimestamp - firstChunkTimestamp).toBeGreaterThan(firstChunkFinalResultTimeDiff); | ||
}); | ||
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
@@ -294,2 +469,22 @@ const chat = generativeTextModel.startChat({ | ||
}); | ||
it('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const chat = generativeTextModelPreview.startChat({ | ||
tools: TOOLS_WITH_FUNCTION_DECLARATION, | ||
}); | ||
const chatInput1 = 'What is the weather in Boston?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream in preview with function calling, for item ${item}`); | ||
} | ||
const response1 = await result1.response; | ||
expect(JSON.stringify(response1.candidates[0].content.parts[0].functionCall)).toContain(FUNCTION_CALL_NAME); | ||
expect(JSON.stringify(response1.candidates[0].content.parts[0].functionCall)).toContain('location'); | ||
// Send a follow up message with a FunctionResponse | ||
const result2 = await chat.sendMessageStream(FUNCTION_RESPONSE_PART); | ||
for await (const item of result2.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on sendMessageStream in preview with function calling, for item ${item}`); | ||
} | ||
const response2 = await result2.response; | ||
expect(JSON.stringify(response2.candidates[0].content.parts[0].text)).toContain(WEATHER_FORECAST); | ||
}); | ||
}); | ||
@@ -301,2 +496,6 @@ describe('countTokens', () => { | ||
}); | ||
it('in preview should return a CountTokensResponse', async () => { | ||
const countTokensResp = await generativeTextModelPreview.countTokens(TEXT_REQUEST); | ||
expect(countTokensResp.totalTokens).toBeTruthy(`sys test failure on countTokens in preview, ${countTokensResp}`); | ||
}); | ||
}); | ||
@@ -315,2 +514,10 @@ describe('generateContentStream using models/model-id', () => { | ||
}); | ||
it('in preview should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = await generativeTextModelWithPrefixPreview.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview using models/gemini-pro, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview using models/gemini-pro for aggregated response: ${aggregatedResp}`); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => { | ||
@@ -324,3 +531,11 @@ const streamingResp = await generativeVisionModelWithPrefix.generateContentStream(MULTI_PART_BASE64_REQUEST); | ||
}); | ||
it('in preview should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => { | ||
const streamingResp = await generativeVisionModelWithPrefixPreview.generateContentStream(MULTI_PART_BASE64_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview using models/gemini-pro-vision, for item ${item}`); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy(`sys test failure on generateContentStream in preview using models/gemini-pro-vision for aggregated response: ${aggregatedResp}`); | ||
}); | ||
}); | ||
//# sourceMappingURL=end_to_end_sample_test.js.map |
@@ -17,8 +17,2 @@ /** | ||
*/ | ||
import { GenerateContentResponse } from '../src/types/content'; | ||
/** | ||
* Returns a generator, used to mock the generateContentStream response | ||
* @ignore | ||
*/ | ||
export declare function testGenerator(): AsyncGenerator<GenerateContentResponse>; | ||
export declare function testGeneratorWithEmptyResponse(): AsyncGenerator<GenerateContentResponse>; | ||
export {}; |
@@ -19,957 +19,14 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.testGeneratorWithEmptyResponse = exports.testGenerator = void 0; | ||
/* tslint:disable */ | ||
const index_1 = require("../src/index"); | ||
const StreamFunctions = require("../src/process_stream"); | ||
const content_1 = require("../src/types/content"); | ||
const errors_1 = require("../src/types/errors"); | ||
const util_1 = require("../src/util"); | ||
const PROJECT = 'test_project'; | ||
const LOCATION = 'test_location'; | ||
const TEST_CHAT_MESSSAGE_TEXT = 'How are you doing today?'; | ||
const TEST_USER_CHAT_MESSAGE = [ | ||
{ role: util_1.constants.USER_ROLE, parts: [{ text: TEST_CHAT_MESSSAGE_TEXT }] }, | ||
]; | ||
const TEST_TOKEN = 'testtoken'; | ||
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [ | ||
{ text: TEST_CHAT_MESSSAGE_TEXT }, | ||
{ | ||
file_data: { | ||
file_uri: 'gs://test_bucket/test_image.jpeg', | ||
mime_type: 'image/jpeg', | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [ | ||
{ text: TEST_CHAT_MESSSAGE_TEXT }, | ||
{ file_data: { file_uri: 'test_image.jpeg', mime_type: 'image/jpeg' } }, | ||
], | ||
}, | ||
]; | ||
const TEST_SAFETY_SETTINGS = [ | ||
{ | ||
category: content_1.HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
threshold: content_1.HarmBlockThreshold.BLOCK_ONLY_HIGH, | ||
}, | ||
]; | ||
const TEST_SAFETY_RATINGS = [ | ||
{ | ||
category: content_1.HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
probability: content_1.HarmProbability.NEGLIGIBLE, | ||
}, | ||
]; | ||
const TEST_GENERATION_CONFIG = { | ||
candidate_count: 1, | ||
stop_sequences: ['hello'], | ||
}; | ||
const TEST_CANDIDATES = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: util_1.constants.MODEL_ROLE, | ||
parts: [{ text: 'Im doing great! How are you?' }], | ||
}, | ||
finishReason: content_1.FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
citationMetadata: { | ||
citationSources: [ | ||
{ | ||
startIndex: 367, | ||
endIndex: 491, | ||
uri: 'https://www.numerade.com/ask/question/why-does-the-uncertainty-principle-make-it-impossible-to-predict-a-trajectory-for-the-clectron-95172/', | ||
}, | ||
], | ||
}, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE = { | ||
candidates: TEST_CANDIDATES, | ||
usage_metadata: { prompt_token_count: 0, candidates_token_count: 0 }, | ||
}; | ||
const TEST_FUNCTION_CALL_RESPONSE = { | ||
functionCall: { | ||
name: 'get_current_weather', | ||
args: { | ||
location: 'LA', | ||
unit: 'fahrenheit', | ||
}, | ||
}, | ||
}; | ||
const TEST_CANDIDATES_WITH_FUNCTION_CALL = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: util_1.constants.MODEL_ROLE, | ||
parts: [TEST_FUNCTION_CALL_RESPONSE], | ||
}, | ||
finishReason: content_1.FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL = { | ||
candidates: TEST_CANDIDATES_WITH_FUNCTION_CALL, | ||
}; | ||
const TEST_FUNCTION_RESPONSE_PART = [ | ||
{ | ||
functionResponse: { | ||
name: 'get_current_weather', | ||
response: { name: 'get_current_weather', content: { weather: 'super nice' } }, | ||
}, | ||
}, | ||
]; | ||
const TEST_CANDIDATES_MISSING_ROLE = [ | ||
{ | ||
index: 1, | ||
content: { parts: [{ text: 'Im doing great! How are you?' }] }, | ||
finish_reason: 0, | ||
finish_message: '', | ||
safety_ratings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE_MISSING_ROLE = { | ||
candidates: TEST_CANDIDATES_MISSING_ROLE, | ||
usage_metadata: { prompt_token_count: 0, candidates_token_count: 0 }, | ||
}; | ||
const TEST_EMPTY_MODEL_RESPONSE = { | ||
candidates: [], | ||
}; | ||
const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com'; | ||
const TEST_GCS_FILENAME = 'gs://test_bucket/test_image.jpeg'; | ||
const TEST_MULTIPART_MESSAGE = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [ | ||
{ text: 'What is in this picture?' }, | ||
{ file_data: { file_uri: TEST_GCS_FILENAME, mime_type: 'image/jpeg' } }, | ||
], | ||
}, | ||
]; | ||
const BASE_64_IMAGE = 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; | ||
const INLINE_DATA_FILE_PART = { | ||
inline_data: { | ||
data: BASE_64_IMAGE, | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const TEST_MULTIPART_MESSAGE_BASE64 = [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [{ text: 'What is in this picture?' }, INLINE_DATA_FILE_PART], | ||
}, | ||
]; | ||
const TEST_TOOLS_WITH_FUNCTION_DECLARATION = [ | ||
{ | ||
function_declarations: [ | ||
{ | ||
name: 'get_current_weather', | ||
description: 'get weather in a given location', | ||
parameters: { | ||
type: content_1.FunctionDeclarationSchemaType.OBJECT, | ||
properties: { | ||
location: { type: content_1.FunctionDeclarationSchemaType.STRING }, | ||
unit: { | ||
type: content_1.FunctionDeclarationSchemaType.STRING, | ||
enum: ['celsius', 'fahrenheit'], | ||
}, | ||
}, | ||
required: ['location'], | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const fetchResponseObj = { | ||
status: 200, | ||
statusText: 'OK', | ||
ok: true, | ||
headers: { 'Content-Type': 'application/json' }, | ||
url: 'url', | ||
}; | ||
/** | ||
* Returns a generator, used to mock the generateContentStream response | ||
* @ignore | ||
*/ | ||
async function* testGenerator() { | ||
yield { | ||
candidates: TEST_CANDIDATES, | ||
}; | ||
} | ||
exports.testGenerator = testGenerator; | ||
async function* testGeneratorWithEmptyResponse() { | ||
yield { | ||
candidates: [], | ||
}; | ||
} | ||
exports.testGeneratorWithEmptyResponse = testGeneratorWithEmptyResponse; | ||
describe('VertexAI', () => { | ||
let vertexai; | ||
let model; | ||
let expectedStreamResult; | ||
let fetchSpy; | ||
beforeEach(() => { | ||
vertexai = new index_1.VertexAI({ | ||
describe('SDK', () => { | ||
it('should import VertexAI', () => { | ||
const PROJECT = 'test_project'; | ||
const LOCATION = 'test_location'; | ||
const vertexai = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
model = vertexai.preview.getGenerativeModel({ model: 'gemini-pro' }); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = new Response(JSON.stringify(expectedStreamResult), fetchResponseObj); | ||
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
}); | ||
it('given undefined google auth options, should be instantiated', () => { | ||
expect(vertexai).toBeInstanceOf(index_1.VertexAI); | ||
}); | ||
it('given specified google auth options, should be instantiated', () => { | ||
const googleAuthOptions = { | ||
scopes: 'https://www.googleapis.com/auth/cloud-platform', | ||
}; | ||
const vertexai1 = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: googleAuthOptions, | ||
}); | ||
expect(vertexai1).toBeInstanceOf(index_1.VertexAI); | ||
}); | ||
it('given inconsistent project ID, should throw error', () => { | ||
const googleAuthOptions = { | ||
projectId: 'another_project', | ||
}; | ||
expect(() => { | ||
new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: googleAuthOptions, | ||
}); | ||
}).toThrow(new Error('inconsistent project ID values. argument project got value test_project but googleAuthOptions.projectId got value another_project')); | ||
}); | ||
it('given scopes missing required scope, should throw GoogleAuthError', () => { | ||
const invalidGoogleAuthOptionsStringScopes = { scopes: 'test.scopes' }; | ||
expect(() => { | ||
new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: invalidGoogleAuthOptionsStringScopes, | ||
}); | ||
}).toThrow(new errors_1.GoogleAuthError("input GoogleAuthOptions.scopes test.scopes doesn't contain required scope " + | ||
'https://www.googleapis.com/auth/cloud-platform, ' + | ||
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' + | ||
'or leave GoogleAuthOptions.scopes undefined')); | ||
const invalidGoogleAuthOptionsArrayScopes = { | ||
scopes: ['test1.scopes', 'test2.scopes'], | ||
}; | ||
expect(() => { | ||
new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: invalidGoogleAuthOptionsArrayScopes, | ||
}); | ||
}).toThrow(new errors_1.GoogleAuthError("input GoogleAuthOptions.scopes test1.scopes,test2.scopes doesn't contain required scope " + | ||
'https://www.googleapis.com/auth/cloud-platform, ' + | ||
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' + | ||
'or leave GoogleAuthOptions.scopes undefined')); | ||
}); | ||
describe('generateContent', () => { | ||
it('returns a GenerateContentResponse', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContent(TEST_CHAT_MESSSAGE_TEXT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a GCS URI', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('raises an error when passed an invalid GCS URI', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE, | ||
}; | ||
await expectAsync(model.generateContent(req)).toBeRejectedWithError(URIError); | ||
}); | ||
it('returns a GenerateContentResponse when passed safety_settings and generation_config', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
safety_settings: TEST_SAFETY_SETTINGS, | ||
generation_config: TEST_GENERATION_CONFIG, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('updates the base API endpoint when provided', async () => { | ||
const vertexaiWithBasePath = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
apiEndpoint: TEST_ENDPOINT_BASE_PATH, | ||
}); | ||
model = vertexaiWithBasePath.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
await model.generateContent(req); | ||
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain(TEST_ENDPOINT_BASE_PATH); | ||
}); | ||
it('default the base API endpoint when base API not provided', async () => { | ||
const vertexaiWithoutBasePath = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
model = vertexaiWithoutBasePath.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
await model.generateContent(req); | ||
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain(`${LOCATION}-aiplatform.googleapis.com`); | ||
}); | ||
it('removes top_k when it is set to 0', async () => { | ||
const reqWithEmptyConfigs = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: { top_k: 0 }, | ||
safety_settings: [], | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
await model.generateContent(reqWithEmptyConfigs); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).not.toContain('top_k'); | ||
} | ||
}); | ||
it('includes top_k when it is within 1 - 40', async () => { | ||
const reqWithEmptyConfigs = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: { top_k: 1 }, | ||
safety_settings: [], | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
await model.generateContent(reqWithEmptyConfigs); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).toContain('top_k'); | ||
} | ||
}); | ||
it('aggregates citation metadata', async () => { | ||
var _a; | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContent(req); | ||
expect((_a = resp.response.candidates[0].citationMetadata) === null || _a === void 0 ? void 0 : _a.citationSources.length).toEqual(TEST_MODEL_RESPONSE.candidates[0].citationMetadata.citationSources | ||
.length); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 1', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 2', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('generateContentStream', () => { | ||
it('returns a GenerateContentResponse when passed text content', async () => { | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(TEST_CHAT_MESSSAGE_TEXT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with a GCS URI', async () => { | ||
const req = { | ||
contents: TEST_MULTIPART_MESSAGE, | ||
}; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with base64 data', async () => { | ||
const req = { | ||
contents: TEST_MULTIPART_MESSAGE_BASE64, | ||
}; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedStreamResult); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedStreamResult); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 1', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 2', async () => { | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather like in Boston?' }] }, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('startChat', () => { | ||
it('returns a ChatSession when passed a request arg', () => { | ||
const req = { | ||
history: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const resp = model.startChat(req); | ||
expect(resp).toBeInstanceOf(index_1.ChatSession); | ||
}); | ||
it('returns a ChatSession when passed no request arg', () => { | ||
const resp = model.startChat(); | ||
expect(resp).toBeInstanceOf(index_1.ChatSession); | ||
}); | ||
}); | ||
}); | ||
describe('countTokens', () => { | ||
it('returns the token count', async () => { | ||
const vertexai = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const model = vertexai.preview.getGenerativeModel({ model: 'gemini-pro' }); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const responseBody = { | ||
totalTokens: 1, | ||
}; | ||
const response = new Response(JSON.stringify(responseBody), fetchResponseObj); | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
const resp = await model.countTokens(req); | ||
expect(resp).toEqual(responseBody); | ||
}); | ||
}); | ||
describe('ChatSession', () => { | ||
let chatSession; | ||
let chatSessionWithNoArgs; | ||
let chatSessionWithEmptyResponse; | ||
let chatSessionWithFunctionCall; | ||
let vertexai; | ||
let model; | ||
let expectedStreamResult; | ||
beforeEach(() => { | ||
vertexai = new index_1.VertexAI({ project: PROJECT, location: LOCATION }); | ||
model = vertexai.preview.getGenerativeModel({ model: 'gemini-pro' }); | ||
chatSession = model.startChat({ | ||
history: TEST_USER_CHAT_MESSAGE, | ||
}); | ||
expect(chatSession.history).toEqual(TEST_USER_CHAT_MESSAGE); | ||
chatSessionWithNoArgs = model.startChat(); | ||
chatSessionWithEmptyResponse = model.startChat(); | ||
chatSessionWithFunctionCall = model.startChat({ | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}); | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = Promise.resolve(new Response(JSON.stringify(expectedStreamResult), fetchResponseObj)); | ||
spyOn(global, 'fetch').and.returnValue(fetchResult); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
}); | ||
describe('sendMessage', () => { | ||
it('returns a GenerateContentResponse and appends to history', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await chatSession.sendMessage(req); | ||
expect(resp).toEqual(expectedResult); | ||
expect(chatSession.history.length).toEqual(3); | ||
}); | ||
it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await chatSessionWithNoArgs.sendMessage(req); | ||
expect(resp).toEqual(expectedResult); | ||
expect(chatSessionWithNoArgs.history.length).toEqual(2); | ||
}); | ||
it('throws an error when the model returns an empty response', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult = { | ||
response: TEST_EMPTY_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
await expectAsync(chatSessionWithEmptyResponse.sendMessage(req)).toBeRejected(); | ||
expect(chatSessionWithEmptyResponse.history.length).toEqual(0); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content', async () => { | ||
const req = TEST_MULTIPART_MESSAGE[0]['parts']; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue(expectedResult); | ||
const resp = await chatSessionWithNoArgs.sendMessage(req); | ||
expect(resp).toEqual(expectedResult); | ||
expect(chatSessionWithNoArgs.history.length).toEqual(2); | ||
}); | ||
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => { | ||
const functionCallChatMessage = 'What is the weather in LA?'; | ||
const expectedFunctionCallResponse = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
const expectedResult = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
const streamSpy = spyOn(StreamFunctions, 'processNonStream'); | ||
streamSpy.and.returnValue(expectedResult); | ||
const response1 = await chatSessionWithFunctionCall.sendMessage(functionCallChatMessage); | ||
expect(response1).toEqual(expectedFunctionCallResponse); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(2); | ||
// Send a follow-up message with a FunctionResponse | ||
const expectedFollowUpResponse = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
const expectedFollowUpResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
streamSpy.and.returnValue(expectedFollowUpResult); | ||
const response2 = await chatSessionWithFunctionCall.sendMessage(TEST_FUNCTION_RESPONSE_PART); | ||
expect(response2).toEqual(expectedFollowUpResponse); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(4); | ||
}); | ||
it('throws ClientError when the request has no content', async () => { | ||
const expectedErrorMessage = '[VertexAI.ClientError]: No content is provided for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessage([]).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when the request mixes a functionResponse part with other part types', async () => { | ||
const chatRequest = [ | ||
'what is the weather like in LA', | ||
TEST_FUNCTION_RESPONSE_PART[0], | ||
]; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessage(chatRequest).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('sendMessageStream', () => { | ||
it('returns a StreamGenerateContentResponse and appends to history', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const chatSession = model.startChat({ | ||
history: [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [{ text: 'How are you doing today?' }], | ||
}, | ||
], | ||
}); | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
expect(chatSession.history.length).toEqual(1); | ||
expect(chatSession.history[0].role).toEqual(util_1.constants.USER_ROLE); | ||
const result = await chatSession.sendMessageStream(req); | ||
const response = await result.response; | ||
const expectedResponse = await expectedResult.response; | ||
expect(response).toEqual(expectedResponse); | ||
expect(chatSession.history.length).toEqual(3); | ||
expect(chatSession.history[0].role).toEqual(util_1.constants.USER_ROLE); | ||
expect(chatSession.history[1].role).toEqual(util_1.constants.USER_ROLE); | ||
expect(chatSession.history[2].role).toEqual(util_1.constants.MODEL_ROLE); | ||
}); | ||
it('returns a StreamGenerateContentResponse and appends role if missing', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_MISSING_ROLE), | ||
stream: testGenerator(), | ||
}; | ||
const chatSession = model.startChat({ | ||
history: [ | ||
{ | ||
role: util_1.constants.USER_ROLE, | ||
parts: [{ text: 'How are you doing today?' }], | ||
}, | ||
], | ||
}); | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
expect(chatSession.history.length).toEqual(1); | ||
expect(chatSession.history[0].role).toEqual(util_1.constants.USER_ROLE); | ||
const result = await chatSession.sendMessageStream(req); | ||
const response = await result.response; | ||
const expectedResponse = await expectedResult.response; | ||
expect(response).toEqual(expectedResponse); | ||
expect(chatSession.history.length).toEqual(3); | ||
expect(chatSession.history[0].role).toEqual(util_1.constants.USER_ROLE); | ||
expect(chatSession.history[1].role).toEqual(util_1.constants.USER_ROLE); | ||
expect(chatSession.history[2].role).toEqual(util_1.constants.MODEL_ROLE); | ||
}); | ||
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => { | ||
const functionCallChatMessage = 'What is the weather in LA?'; | ||
const expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL), | ||
stream: testGenerator(), | ||
}; | ||
const streamSpy = spyOn(StreamFunctions, 'processStream'); | ||
streamSpy.and.returnValue(expectedStreamResult); | ||
const response1 = await chatSessionWithFunctionCall.sendMessageStream(functionCallChatMessage); | ||
expect(response1).toEqual(expectedStreamResult); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(2); | ||
// Send a follow-up message with a FunctionResponse | ||
const expectedFollowUpStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
streamSpy.and.returnValue(expectedFollowUpStreamResult); | ||
const response2 = await chatSessionWithFunctionCall.sendMessageStream(TEST_FUNCTION_RESPONSE_PART); | ||
expect(response2).toEqual(expectedFollowUpStreamResult); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(4); | ||
}); | ||
it('throws ClientError when the request has no content', async () => { | ||
const expectedErrorMessage = '[VertexAI.ClientError]: No content is provided for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessageStream([]).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when the request mixes a functionResponse part with other part types', async () => { | ||
const chatRequest = [ | ||
'what is the weather like in LA', | ||
TEST_FUNCTION_RESPONSE_PART[0], | ||
]; | ||
const expectedErrorMessage = '[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessageStream(chatRequest).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
}); | ||
describe('when exception at fetch', () => { | ||
const expectedErrorMessage = '[VertexAI.GoogleGenerativeAIError]: exception posting request'; | ||
const vertexai = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const model = vertexai.preview.getGenerativeModel({ model: 'gemini-pro' }); | ||
const chatSession = model.startChat(); | ||
const message = 'hi'; | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const countTokenReq = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
spyOn(global, 'fetch').and.throwError('error'); | ||
}); | ||
it('generateContent should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
}); | ||
it('generateContentStream should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
}); | ||
it('sendMessage should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
}); | ||
it('countTokens should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
}); | ||
}); | ||
describe('when response is undefined', () => { | ||
const expectedErrorMessage = '[VertexAI.GoogleGenerativeAIError]: response is undefined'; | ||
const vertexai = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const model = vertexai.preview.getGenerativeModel({ model: 'gemini-pro' }); | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const message = 'hi'; | ||
const chatSession = model.startChat(); | ||
const countTokenReq = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOn(global, 'fetch').and.resolveTo(); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
}); | ||
it('generateContent should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('generateContentStream should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('sendMessage should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
await chatSession.sendMessage(message).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('countTokens should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
await model.countTokens(countTokenReq).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('when response is 4XX', () => { | ||
const expectedErrorMessage = '[VertexAI.ClientError]: got status: 400 Bad Request'; | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const vertexai = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const fetch400Obj = { | ||
status: 400, | ||
statusText: 'Bad Request', | ||
ok: false, | ||
}; | ||
const body = {}; | ||
const response = new Response(JSON.stringify(body), fetch400Obj); | ||
const model = vertexai.preview.getGenerativeModel({ model: 'gemini-pro' }); | ||
const message = 'hi'; | ||
const chatSession = model.startChat(); | ||
const countTokenReq = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
}); | ||
it('generateContent should throw ClientError', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('generateContentStream should throw ClientError', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('sendMessage should throw ClientError', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
await chatSession.sendMessage(message).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('countTokens should throw ClientError', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
await model.countTokens(countTokenReq).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('when response is not OK and not 4XX', () => { | ||
const expectedErrorMessage = '[VertexAI.GoogleGenerativeAIError]: got status: 500 Internal Server Error'; | ||
const req = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const vertexai = new index_1.VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const fetch500Obj = { | ||
status: 500, | ||
statusText: 'Internal Server Error', | ||
ok: false, | ||
}; | ||
const body = {}; | ||
const response = new Response(JSON.stringify(body), fetch500Obj); | ||
const model = vertexai.preview.getGenerativeModel({ model: 'gemini-pro' }); | ||
const message = 'hi'; | ||
const chatSession = model.startChat(); | ||
const countTokenReq = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
}); | ||
it('generateContent should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('generateContentStream should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('sendMessage should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
await chatSession.sendMessage(message).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('countTokens should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
await model.countTokens(countTokenReq).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
//# sourceMappingURL=index_test.js.map |
# Changelog | ||
## [0.4.0](https://github.com/googleapis/nodejs-vertexai/compare/v0.3.1...v0.4.0) (2024-02-15) | ||
### Features | ||
* Added support for Grounding ([929df39](https://github.com/googleapis/nodejs-vertexai/commit/929df39f19f423bcfaf35ef113ce04886345a6ab)) | ||
* enable both GA and preview namespaces (see the usage sketch after this entry). ([1c2aca6](https://github.com/googleapis/nodejs-vertexai/commit/1c2aca6b776784a5b51d1654ffa41dc36f600874)) | ||
### Bug Fixes | ||
* throw more details on error message. ([5dba79c](https://github.com/googleapis/nodejs-vertexai/commit/5dba79c3648203b9a66b6098f9f1fa0280e6e67d)) | ||
* unary api should only need to `await` once. ([67a2e96](https://github.com/googleapis/nodejs-vertexai/commit/67a2e9649c69a2cf9868a074527efd93d2c800c9)) | ||
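For illustration, a minimal sketch of the dual namespaces, assuming placeholder project values; the README diff further below makes the same move in the sample code:

```typescript
import {VertexAI} from '@google-cloud/vertexai';

// Hypothetical project settings; only the two getGenerativeModel entry points matter here.
const vertexAI = new VertexAI({project: 'your-project', location: 'us-central1'});

// GA namespace: models are now available directly on the client.
const gaModel = vertexAI.getGenerativeModel({model: 'gemini-pro'});

// Preview namespace: the pre-0.4.0 entry point keeps working.
const previewModel = vertexAI.preview.getGenerativeModel({model: 'gemini-pro'});
```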
## [0.3.1](https://github.com/googleapis/nodejs-vertexai/compare/v0.3.0...v0.3.1) (2024-02-06) | ||
@@ -4,0 +18,0 @@ |
{ | ||
"name": "@google-cloud/vertexai", | ||
"description": "Vertex Generative AI client for Node.js", | ||
"version": "0.3.1", | ||
"version": "0.4.0", | ||
"license": "Apache-2.0", | ||
@@ -6,0 +6,0 @@ "author": "Google LLC", |
@@ -36,3 +36,3 @@ # Vertex AI Node.js SDK | ||
// Instantiate models | ||
const generativeModel = vertex_ai.preview.getGenerativeModel({ | ||
const generativeModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
@@ -45,3 +45,3 @@ // The following parameters are optional | ||
const generativeVisionModel = vertex_ai.preview.getGenerativeModel({ | ||
const generativeVisionModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro-vision', | ||
@@ -91,3 +91,3 @@ }); | ||
const filePart = {file_data: {file_uri: "gs://generativeai-downloads/images/scones.jpg", mime_type: "image/jpeg"}}; | ||
const textPart = {text: 'What is this a picture of?'}; | ||
const textPart = {text: 'What is this picture about?'}; | ||
const request = { | ||
@@ -113,3 +113,3 @@ contents: [{role: 'user', parts: [textPart, filePart]}], | ||
const filePart = {inline_data: {data: base64Image, mime_type: 'image/jpeg'}}; | ||
const textPart = {text: 'What is this a picture of?'}; | ||
const textPart = {text: 'What is this picture about?'}; | ||
const request = { | ||
@@ -116,0 +116,0 @@ contents: [{role: 'user', parts: [textPart, filePart]}], |
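The README fragment above builds `request` but stops short of the call; a hedged sketch of sending it, run inside an async context, where `generativeVisionModel` and `request` come from the snippets above:

```typescript
// Sketch only: assumes the README's generativeVisionModel and multimodal request.
const streamingResult = await generativeVisionModel.generateContentStream(request);
const aggregatedResponse = await streamingResult.response;
console.log(aggregatedResponse.candidates[0].content);
```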
@@ -36,4 +36,4 @@ /** | ||
token: Promise<any>, | ||
apiEndpoint: string, | ||
request: CountTokensRequest | ||
request: CountTokensRequest, | ||
apiEndpoint?: string | ||
): Promise<CountTokensResponse> { | ||
@@ -51,4 +51,6 @@ const response = await postRequest({ | ||
}); | ||
throwErrorIfNotOK(response); | ||
await throwErrorIfNotOK(response).catch(e => { | ||
throw e; | ||
}); | ||
return processCountTokenResponse(response); | ||
} |
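The substantive change to `countTokens` is parameter order: `request` moves ahead of `apiEndpoint`, which becomes optional and last. A hedged sketch of a call under the new signature; the import path and all values are placeholder assumptions:

```typescript
import {countTokens} from './functions/count_tokens'; // path assumed

async function demoCountTokens() {
  const resp = await countTokens(
    'us-central1', // location
    'my-project', // project
    'publishers/google/models/gemini-pro', // publisherModelEndpoint (placeholder)
    Promise.resolve('access-token'), // token
    {contents: [{role: 'user', parts: [{text: 'hi'}]}]}, // request now comes fifth
    // apiEndpoint omitted: the regional default endpoint is used
  );
  console.log(resp.totalTokens);
}
```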
@@ -17,1 +17,125 @@ /** | ||
*/ | ||
/** | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResponse object with the response candidates. | ||
*/ | ||
import { | ||
GenerateContentRequest, | ||
GenerateContentResult, | ||
GenerationConfig, | ||
SafetySetting, | ||
StreamGenerateContentResult, | ||
} from '../types/content'; | ||
import {GoogleGenerativeAIError} from '../types/errors'; | ||
import * as constants from '../util/constants'; | ||
import { | ||
processNonStream, | ||
processStream, | ||
throwErrorIfNotOK, | ||
} from './post_fetch_processing'; | ||
import {postRequest} from './post_request'; | ||
import { | ||
formatContentRequest, | ||
validateGenerateContentRequest, | ||
validateGenerationConfig, | ||
} from './pre_fetch_processing'; | ||
export async function generateContent( | ||
location: string, | ||
project: string, | ||
publisherModelEndpoint: string, | ||
token: Promise<any>, | ||
request: GenerateContentRequest | string, | ||
apiEndpoint?: string, | ||
generation_config?: GenerationConfig, | ||
safety_settings?: SafetySetting[] | ||
): Promise<GenerateContentResult> { | ||
request = formatContentRequest(request, generation_config, safety_settings); | ||
validateGenerateContentRequest(request); | ||
if (request.generation_config) { | ||
request.generation_config = validateGenerationConfig( | ||
request.generation_config | ||
); | ||
} | ||
const generateContentRequest: GenerateContentRequest = { | ||
contents: request.contents, | ||
generation_config: request.generation_config ?? generation_config, | ||
safety_settings: request.safety_settings ?? safety_settings, | ||
tools: request.tools ?? [], | ||
}; | ||
const apiVersion = request.tools ? 'v1beta1' : 'v1'; | ||
const response: Response | undefined = await postRequest({ | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: constants.GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
apiVersion: apiVersion, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
await throwErrorIfNotOK(response).catch(e => { | ||
throw e; | ||
}); | ||
return processNonStream(response); | ||
} | ||
/** | ||
* Make an async stream request to generate content. The response will be | ||
* returned as a stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link | ||
* StreamGenerateContentResult} | ||
*/ | ||
export async function generateContentStream( | ||
location: string, | ||
project: string, | ||
publisherModelEndpoint: string, | ||
token: Promise<any>, | ||
request: GenerateContentRequest | string, | ||
apiEndpoint?: string, | ||
generation_config?: GenerationConfig, | ||
safety_settings?: SafetySetting[] | ||
): Promise<StreamGenerateContentResult> { | ||
request = formatContentRequest(request, generation_config, safety_settings); | ||
validateGenerateContentRequest(request); | ||
if (request.generation_config) { | ||
request.generation_config = validateGenerationConfig( | ||
request.generation_config | ||
); | ||
} | ||
const generateContentRequest: GenerateContentRequest = { | ||
contents: request.contents, | ||
generation_config: request.generation_config ?? generation_config, | ||
safety_settings: request.safety_settings ?? safety_settings, | ||
tools: request.tools ?? [], | ||
}; | ||
const apiVersion = request.tools ? 'v1beta1' : 'v1'; | ||
const response = await postRequest({ | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
apiVersion: apiVersion, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
await throwErrorIfNotOK(response).catch(e => { | ||
throw e; | ||
}); | ||
return processStream(response); | ||
} |
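`generateContentStream` resolves to a `StreamGenerateContentResult` carrying both a `stream` iterator and an aggregated `response` promise. A hedged consumption sketch, with placeholder endpoint and token values and imports as in the tests below:

```typescript
async function printStreamedAnswer() {
  const streamingResult = await generateContentStream(
    'us-central1', // location
    'my-project', // project
    'publishers/google/models/gemini-pro', // publisherModelEndpoint (placeholder)
    Promise.resolve('access-token'), // token
    'How are you doing today?' // a bare string is formatted into a full request
  );
  // Partial responses arrive through the async iterator...
  for await (const chunk of streamingResult.stream) {
    console.log(chunk.candidates[0]?.content);
  }
  // ...and the aggregated response resolves once the stream ends.
  const aggregated = await streamingResult.response;
  console.log(aggregated.candidates[0]?.content);
}
```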
@@ -20,1 +20,2 @@ /** | ||
export {postRequest} from './post_request'; | ||
export {generateContent, generateContentStream} from './generate_content'; |
@@ -17,16 +17,27 @@ /** | ||
*/ | ||
import {CountTokensResponse} from '../types/content'; | ||
import { | ||
CitationSource, | ||
CountTokensResponse, | ||
GenerateContentCandidate, | ||
GenerateContentResponse, | ||
GenerateContentResult, | ||
StreamGenerateContentResult, | ||
} from '../types/content'; | ||
import {ClientError, GoogleGenerativeAIError} from '../types/errors'; | ||
export function throwErrorIfNotOK(response: Response | undefined) { | ||
export async function throwErrorIfNotOK(response: Response | undefined) { | ||
if (response === undefined) { | ||
throw new GoogleGenerativeAIError('response is undefined'); | ||
} | ||
const status: number = response.status; | ||
const statusText: string = response.statusText; | ||
const errorMessage = `got status: ${status} ${statusText}`; | ||
if (status >= 400 && status < 500) { | ||
throw new ClientError(errorMessage); | ||
} | ||
if (!response.ok) { | ||
const status: number = response.status; | ||
const statusText: string = response.statusText; | ||
const errorBody = await response.json(); | ||
const errorMessage = `got status: ${status} ${statusText}. ${JSON.stringify( | ||
errorBody | ||
)}`; | ||
if (status >= 400 && status < 500) { | ||
throw new ClientError(errorMessage); | ||
} | ||
throw new GoogleGenerativeAIError(errorMessage); | ||
@@ -36,10 +47,216 @@ } | ||
const responseLineRE = /^data: (.*)(?:\n\n|\r\r|\r\n\r\n)/; | ||
async function* generateResponseSequence( | ||
stream: ReadableStream<GenerateContentResponse> | ||
): AsyncGenerator<GenerateContentResponse> { | ||
const reader = stream.getReader(); | ||
while (true) { | ||
const {value, done} = await reader.read(); | ||
if (done) { | ||
break; | ||
} | ||
yield value; | ||
} | ||
} | ||
/** | ||
* Process a response.body stream from the backend and return an | ||
* iterator that provides one complete GenerateContentResponse at a time | ||
* and a promise that resolves with a single aggregated | ||
* GenerateContentResponse. | ||
* | ||
* @param response - Response from a fetch call | ||
* @ignore | ||
*/ | ||
export async function processStream( | ||
response: Response | undefined | ||
): Promise<StreamGenerateContentResult> { | ||
if (response === undefined) { | ||
throw new Error('Error processing stream because response === undefined'); | ||
} | ||
if (!response.body) { | ||
throw new Error('Error processing stream because response.body not found'); | ||
} | ||
const inputStream = response.body!.pipeThrough( | ||
new TextDecoderStream('utf8', {fatal: true}) | ||
); | ||
const responseStream = | ||
getResponseStream<GenerateContentResponse>(inputStream); | ||
const [stream1, stream2] = responseStream.tee(); | ||
return Promise.resolve({ | ||
stream: generateResponseSequence(stream1), | ||
response: getResponsePromise(stream2), | ||
}); | ||
} | ||
async function getResponsePromise( | ||
stream: ReadableStream<GenerateContentResponse> | ||
): Promise<GenerateContentResponse> { | ||
const allResponses: GenerateContentResponse[] = []; | ||
const reader = stream.getReader(); | ||
// eslint-disable-next-line no-constant-condition | ||
while (true) { | ||
const {done, value} = await reader.read(); | ||
if (done) { | ||
return aggregateResponses(allResponses); | ||
} | ||
allResponses.push(value); | ||
} | ||
} | ||
/** | ||
* Reads a raw stream from the fetch response and join incomplete | ||
* chunks, returning a new stream that provides a single complete | ||
* GenerateContentResponse in each iteration. | ||
* @ignore | ||
*/ | ||
export function getResponseStream<T>( | ||
inputStream: ReadableStream<string> | ||
): ReadableStream<T> { | ||
const reader = inputStream.getReader(); | ||
const stream = new ReadableStream<T>({ | ||
start(controller) { | ||
let currentText = ''; | ||
return pump(); | ||
function pump(): Promise<(() => Promise<void>) | undefined> { | ||
return reader.read().then(({value, done}) => { | ||
if (done) { | ||
if (currentText.trim()) { | ||
controller.error(new Error('Failed to parse stream')); | ||
return; | ||
} | ||
controller.close(); | ||
return; | ||
} | ||
currentText += value; | ||
let match = currentText.match(responseLineRE); | ||
let parsedResponse: T; | ||
while (match) { | ||
try { | ||
parsedResponse = JSON.parse(match[1]) as T; | ||
} catch (e) { | ||
controller.error( | ||
new Error(`Error parsing JSON response: "${match[1]}"`) | ||
); | ||
return; | ||
} | ||
controller.enqueue(parsedResponse); | ||
currentText = currentText.substring(match[0].length); | ||
match = currentText.match(responseLineRE); | ||
} | ||
return pump(); | ||
}); | ||
} | ||
}, | ||
}); | ||
return stream; | ||
} | ||
/** | ||
* Aggregates an array of `GenerateContentResponse`s into a single | ||
* GenerateContentResponse. | ||
* @ignore | ||
*/ | ||
function aggregateResponses( | ||
responses: GenerateContentResponse[] | ||
): GenerateContentResponse { | ||
const lastResponse = responses[responses.length - 1]; | ||
if (lastResponse === undefined) { | ||
throw new Error( | ||
'Error processing stream because the response is undefined' | ||
); | ||
} | ||
const aggregatedResponse: GenerateContentResponse = { | ||
candidates: [], | ||
promptFeedback: lastResponse.promptFeedback, | ||
}; | ||
for (const response of responses) { | ||
for (let i = 0; i < response.candidates.length; i++) { | ||
if (!aggregatedResponse.candidates[i]) { | ||
aggregatedResponse.candidates[i] = { | ||
index: response.candidates[i].index, | ||
content: { | ||
role: response.candidates[i].content.role, | ||
parts: [{text: ''}], | ||
}, | ||
} as GenerateContentCandidate; | ||
} | ||
if (response.candidates[i].citationMetadata) { | ||
if ( | ||
!aggregatedResponse.candidates[i].citationMetadata?.citationSources | ||
) { | ||
aggregatedResponse.candidates[i].citationMetadata = { | ||
citationSources: [] as CitationSource[], | ||
}; | ||
} | ||
// Merge this chunk's citationSources array into the aggregate. | ||
const newSources = | ||
response.candidates[i].citationMetadata?.citationSources ?? []; | ||
if (aggregatedResponse.candidates[i].citationMetadata) { | ||
aggregatedResponse.candidates[i].citationMetadata!.citationSources = | ||
aggregatedResponse.candidates[ | ||
i | ||
].citationMetadata!.citationSources.concat(newSources); | ||
} | ||
} | ||
aggregatedResponse.candidates[i].finishReason = | ||
response.candidates[i].finishReason; | ||
aggregatedResponse.candidates[i].finishMessage = | ||
response.candidates[i].finishMessage; | ||
aggregatedResponse.candidates[i].safetyRatings = | ||
response.candidates[i].safetyRatings; | ||
if ('parts' in response.candidates[i].content) { | ||
for (const part of response.candidates[i].content.parts) { | ||
if (part.text) { | ||
aggregatedResponse.candidates[i].content.parts[0].text += part.text; | ||
} | ||
if (part.functionCall) { | ||
aggregatedResponse.candidates[i].content.parts[0].functionCall = | ||
part.functionCall; | ||
// the empty 'text' key should be removed if functionCall is in the | ||
// response | ||
delete aggregatedResponse.candidates[i].content.parts[0].text; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
aggregatedResponse.promptFeedback = | ||
responses[responses.length - 1].promptFeedback; | ||
return aggregatedResponse; | ||
} | ||
/** | ||
* Process model responses from generateContent | ||
* @ignore | ||
*/ | ||
export async function processNonStream( | ||
response: any | ||
): Promise<GenerateContentResult> { | ||
if (response !== undefined) { | ||
// ts-ignore | ||
const responseJson = await response.json(); | ||
return Promise.resolve({ | ||
response: responseJson, | ||
}); | ||
} | ||
return Promise.resolve({ | ||
response: {candidates: []}, | ||
}); | ||
} | ||
/** | ||
* Process model responses from countTokens | ||
* @ignore | ||
*/ | ||
export function processCountTokenResponse(response: any): CountTokensResponse { | ||
export async function processCountTokenResponse( | ||
response: any | ||
): Promise<CountTokensResponse> { | ||
// ts-ignore | ||
const responseJson = response.json(); | ||
return responseJson as CountTokensResponse; | ||
return response.json(); | ||
} |
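The net effect of the `throwErrorIfNotOK` change is that callers must now await it, and the thrown message carries the JSON error body. A hedged sketch of the resulting error shape using a stand-in `Response`; the `[VertexAI.ClientError]` prefix is added by the SDK's error class, as the tests below confirm, and the import path is assumed:

```typescript
import {throwErrorIfNotOK} from './functions/post_fetch_processing'; // path assumed

async function demoErrorShape() {
  // Stand-in 400 response with a JSON error body.
  const bad = new Response(
    JSON.stringify({code: 400, message: 'request is invalid', status: 'INVALID_ARGUMENT'}),
    {status: 400, statusText: 'Bad Request'}
  );
  await throwErrorIfNotOK(bad).catch(e => {
    // '[VertexAI.ClientError]: got status: 400 Bad Request. {"code":400,...}'
    console.log(e.message);
  });
}
```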
@@ -17,1 +17,76 @@ /** | ||
*/ | ||
import { | ||
Content, | ||
GenerateContentRequest, | ||
GenerationConfig, | ||
SafetySetting, | ||
} from '../types/content'; | ||
import {ClientError} from '../types/errors'; | ||
import * as constants from '../util/constants'; | ||
export function formatContentRequest( | ||
request: GenerateContentRequest | string, | ||
generation_config?: GenerationConfig, | ||
safety_settings?: SafetySetting[] | ||
): GenerateContentRequest { | ||
if (typeof request === 'string') { | ||
return { | ||
contents: [{role: constants.USER_ROLE, parts: [{text: request}]}], | ||
generation_config: generation_config, | ||
safety_settings: safety_settings, | ||
}; | ||
} else { | ||
return request; | ||
} | ||
} | ||
export function validateGenerateContentRequest( | ||
request: GenerateContentRequest | ||
) { | ||
validateGcsInput(request.contents); | ||
validateFunctionResponseRequest(request.contents); | ||
} | ||
export function validateGenerationConfig( | ||
generation_config: GenerationConfig | ||
): GenerationConfig { | ||
if ('top_k' in generation_config) { | ||
if (!(generation_config.top_k! > 0) || !(generation_config.top_k! <= 40)) { | ||
delete generation_config.top_k; | ||
} | ||
} | ||
return generation_config; | ||
} | ||
function validateGcsInput(contents: Content[]) { | ||
for (const content of contents) { | ||
for (const part of content.parts) { | ||
if ('file_data' in part) { | ||
// @ts-ignore | ||
const uri = part['file_data']['file_uri']; | ||
if (!uri.startsWith('gs://')) { | ||
throw new URIError( | ||
`Found invalid Google Cloud Storage URI ${uri}; Google Cloud Storage URIs must start with gs://` | ||
); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
function validateFunctionResponseRequest(contents: Content[]) { | ||
const latestContentPart = contents[contents.length - 1].parts[0]; | ||
if (!('functionResponse' in latestContentPart)) { | ||
return; | ||
} | ||
const errorMessage = | ||
'Please ensure that function response turn comes immediately after a function call turn.'; | ||
if (contents.length < 2) { | ||
throw new ClientError(errorMessage); | ||
} | ||
const secondLatestContentPart = contents[contents.length - 2].parts[0]; | ||
if (!('functionCall' in secondLatestContentPart)) { | ||
throw new ClientError(errorMessage); | ||
} | ||
} |
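Note that the `top_k` guard silently drops out-of-range values instead of rejecting the request, which is the behavior the tests exercise. A quick sketch of the observable effect (import path assumed):

```typescript
import {validateGenerationConfig} from './functions/pre_fetch_processing'; // path assumed

console.log(validateGenerationConfig({top_k: 0})); // {} - invalid value removed
console.log(validateGenerationConfig({top_k: 1})); // {top_k: 1} - kept
console.log(validateGenerationConfig({top_k: 41})); // {} - out of range, removed
```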
@@ -18,5 +18,21 @@ /** | ||
import { | ||
CountTokensRequest, | ||
FinishReason, | ||
FunctionDeclarationSchemaType, | ||
GenerateContentRequest, | ||
GenerateContentResponse, | ||
GenerateContentResult, | ||
HarmBlockThreshold, | ||
HarmCategory, | ||
HarmProbability, | ||
SafetyRating, | ||
SafetySetting, | ||
StreamGenerateContentResult, | ||
Tool, | ||
} from '../../types'; | ||
import {constants} from '../../util'; | ||
import {countTokens} from '../count_tokens'; | ||
import {CountTokensRequest} from '../../types'; | ||
import {constants} from '../../util'; | ||
import {generateContent, generateContentStream} from '../generate_content'; | ||
import * as StreamFunctions from '../post_fetch_processing'; | ||
@@ -26,8 +42,191 @@ const TEST_PROJECT = 'test-project'; | ||
const TEST_PUBLISHER_MODEL_ENDPOINT = 'test-publisher-model-endpoint'; | ||
const TEST_TOKEN_PROMISE = Promise.resolve('test-token'); | ||
const TEST_TOKEN = 'testtoken'; | ||
const TEST_TOKEN_PROMISE = Promise.resolve(TEST_TOKEN); | ||
const TEST_API_ENDPOINT = 'test-api-endpoint'; | ||
const TEST_CHAT_MESSSAGE_TEXT = 'How are you doing today?'; | ||
const TEST_CHAT_MESSAGE_TEXT = 'How are you doing today?'; | ||
const TEST_USER_CHAT_MESSAGE = [ | ||
{role: constants.USER_ROLE, parts: [{text: TEST_CHAT_MESSSAGE_TEXT}]}, | ||
{role: constants.USER_ROLE, parts: [{text: TEST_CHAT_MESSAGE_TEXT}]}, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [ | ||
{text: TEST_CHAT_MESSAGE_TEXT}, | ||
{ | ||
file_data: { | ||
file_uri: 'gs://test_bucket/test_image.jpeg', | ||
mime_type: 'image/jpeg', | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [ | ||
{text: TEST_CHAT_MESSAGE_TEXT}, | ||
{file_data: {file_uri: 'test_image.jpeg', mime_type: 'image/jpeg'}}, | ||
], | ||
}, | ||
]; | ||
const TEST_SAFETY_SETTINGS: SafetySetting[] = [ | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH, | ||
}, | ||
]; | ||
const TEST_SAFETY_RATINGS: SafetyRating[] = [ | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
probability: HarmProbability.NEGLIGIBLE, | ||
}, | ||
]; | ||
const TEST_GENERATION_CONFIG = { | ||
candidate_count: 1, | ||
stop_sequences: ['hello'], | ||
}; | ||
const TEST_CANDIDATES = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: constants.MODEL_ROLE, | ||
parts: [{text: 'I\'m doing great! How are you?'}], | ||
}, | ||
finishReason: FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
citationMetadata: { | ||
citationSources: [ | ||
{ | ||
startIndex: 367, | ||
endIndex: 491, | ||
uri: 'https://www.numerade.com/ask/question/why-does-the-uncertainty-principle-make-it-impossible-to-predict-a-trajectory-for-the-clectron-95172/', | ||
}, | ||
], | ||
}, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE = { | ||
candidates: TEST_CANDIDATES, | ||
usage_metadata: {prompt_token_count: 0, candidates_token_count: 0}, | ||
}; | ||
const TEST_FUNCTION_CALL_RESPONSE = { | ||
functionCall: { | ||
name: 'get_current_weather', | ||
args: { | ||
location: 'LA', | ||
unit: 'fahrenheit', | ||
}, | ||
}, | ||
}; | ||
const TEST_CANDIDATES_WITH_FUNCTION_CALL = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: constants.MODEL_ROLE, | ||
parts: [TEST_FUNCTION_CALL_RESPONSE], | ||
}, | ||
finishReason: FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL = { | ||
candidates: TEST_CANDIDATES_WITH_FUNCTION_CALL, | ||
}; | ||
const TEST_FUNCTION_RESPONSE_PART = [ | ||
{ | ||
functionResponse: { | ||
name: 'get_current_weather', | ||
response: {name: 'get_current_weather', content: {weather: 'super nice'}}, | ||
}, | ||
}, | ||
]; | ||
const TEST_CANDIDATES_MISSING_ROLE = [ | ||
{ | ||
index: 1, | ||
content: {parts: [{text: 'I\'m doing great! How are you?'}]}, | ||
finish_reason: 0, | ||
finish_message: '', | ||
safety_ratings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com'; | ||
const TEST_GCS_FILENAME = 'gs://test_bucket/test_image.jpeg'; | ||
const TEST_MULTIPART_MESSAGE = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [ | ||
{text: 'What is in this picture?'}, | ||
{file_data: {file_uri: TEST_GCS_FILENAME, mime_type: 'image/jpeg'}}, | ||
], | ||
}, | ||
]; | ||
const BASE_64_IMAGE = | ||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; | ||
const INLINE_DATA_FILE_PART = { | ||
inline_data: { | ||
data: BASE_64_IMAGE, | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const TEST_MULTIPART_MESSAGE_BASE64 = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [{text: 'What is in this picture?'}, INLINE_DATA_FILE_PART], | ||
}, | ||
]; | ||
const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [ | ||
{ | ||
function_declarations: [ | ||
{ | ||
name: 'get_current_weather', | ||
description: 'get weather in a given location', | ||
parameters: { | ||
type: FunctionDeclarationSchemaType.OBJECT, | ||
properties: { | ||
location: {type: FunctionDeclarationSchemaType.STRING}, | ||
unit: { | ||
type: FunctionDeclarationSchemaType.STRING, | ||
enum: ['celsius', 'fahrenheit'], | ||
}, | ||
}, | ||
required: ['location'], | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const fetchResponseObj = { | ||
status: 200, | ||
statusText: 'OK', | ||
ok: true, | ||
headers: {'Content-Type': 'application/json'}, | ||
url: 'url', | ||
}; | ||
/** | ||
* Returns a generator, used to mock the generateContentStream response | ||
* @ignore | ||
*/ | ||
export async function* testGenerator(): AsyncGenerator<GenerateContentResponse> { | ||
yield { | ||
candidates: TEST_CANDIDATES, | ||
}; | ||
} | ||
describe('countTokens', () => { | ||
@@ -39,9 +238,2 @@ const req: CountTokensRequest = { | ||
it('returns the expected response when OK', async () => { | ||
const fetchResponseObj = { | ||
status: 200, | ||
statusText: 'OK', | ||
ok: true, | ||
headers: {'Content-Type': 'application/json'}, | ||
url: 'url', | ||
}; | ||
const expectedResponseBody = { | ||
@@ -61,4 +253,4 @@ totalTokens: 1, | ||
TEST_TOKEN_PROMISE, | ||
TEST_API_ENDPOINT, | ||
req | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
@@ -75,6 +267,10 @@ | ||
}; | ||
const body = {}; | ||
const body = { | ||
code: 500, | ||
message: 'service is having downtime', | ||
status: 'INTERNAL_SERVER_ERROR', | ||
}; | ||
const response = new Response(JSON.stringify(body), fetch500Obj); | ||
const expectedErrorMessage = | ||
'[VertexAI.GoogleGenerativeAIError]: got status: 500 Internal Server Error'; | ||
'[VertexAI.GoogleGenerativeAIError]: got status: 500 Internal Server Error. {"code":500,"message":"service is having downtime","status":"INTERNAL_SERVER_ERROR"}'; | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
@@ -88,16 +284,17 @@ | ||
TEST_TOKEN_PROMISE, | ||
TEST_API_ENDPOINT, | ||
req | ||
req, | ||
TEST_API_ENDPOINT | ||
) | ||
).toBeRejected(); | ||
await countTokens( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
TEST_API_ENDPOINT, | ||
req | ||
).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
// TODO: update jasmine version or use flush to uncomment | ||
// await countTokens( | ||
// TEST_LOCATION, | ||
// TEST_PROJECT, | ||
// TEST_PUBLISHER_MODEL_ENDPOINT, | ||
// TEST_TOKEN_PROMISE, | ||
// req, | ||
// TEST_API_ENDPOINT | ||
// ).catch(e => { | ||
// expect(e.message).toEqual(expectedErrorMessage); | ||
// }); | ||
}); | ||
@@ -111,6 +308,10 @@ | ||
}; | ||
const body = {}; | ||
const body = { | ||
code: 400, | ||
message: 'request is invalid', | ||
status: 'INVALID_ARGUMENT', | ||
}; | ||
const response = new Response(JSON.stringify(body), fetch400Obj); | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: got status: 400 Bad Request'; | ||
'[VertexAI.ClientError]: got status: 400 Bad Request. {"code":400,"message":"request is invalid","status":"INVALID_ARGUMENT"}'; | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
@@ -124,7 +325,45 @@ | ||
TEST_TOKEN_PROMISE, | ||
TEST_API_ENDPOINT, | ||
req | ||
req, | ||
TEST_API_ENDPOINT | ||
) | ||
).toBeRejected(); | ||
await countTokens( | ||
// TODO: update jasmine version or use flush to uncomment | ||
// await countTokens( | ||
// TEST_LOCATION, | ||
// TEST_PROJECT, | ||
// TEST_PUBLISHER_MODEL_ENDPOINT, | ||
// TEST_TOKEN_PROMISE, | ||
// req, | ||
// TEST_API_ENDPOINT | ||
// ).catch(e => { | ||
// expect(e.message).toEqual(expectedErrorMessage); | ||
// }); | ||
}); | ||
}); | ||
describe('generateContent', () => { | ||
let expectedStreamResult: StreamGenerateContentResult; | ||
let fetchSpy: jasmine.Spy; | ||
beforeEach(() => { | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = new Response( | ||
JSON.stringify(expectedStreamResult), | ||
fetchResponseObj | ||
); | ||
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
}); | ||
it('returns a GenerateContentResponse', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await generateContent( | ||
TEST_LOCATION, | ||
@@ -134,4 +373,386 @@ TEST_PROJECT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
TEST_CHAT_MESSAGE_TEXT, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a GCS URI', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('raises an error when passed an invalid GCS URI', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE, | ||
}; | ||
await expectAsync( | ||
generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
) | ||
).toBeRejectedWithError(URIError); | ||
}); | ||
it('returns a GenerateContentResponse when passed safety_settings and generation_config', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
safety_settings: TEST_SAFETY_SETTINGS, | ||
generation_config: TEST_GENERATION_CONFIG, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('updates the base API endpoint when provided', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_ENDPOINT_BASE_PATH | ||
); | ||
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain( | ||
TEST_ENDPOINT_BASE_PATH | ||
); | ||
}); | ||
it('removes top_k when it is set to 0', async () => { | ||
const reqWithEmptyConfigs: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: {top_k: 0}, | ||
safety_settings: [], | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
reqWithEmptyConfigs, | ||
TEST_API_ENDPOINT | ||
); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).not.toContain('top_k'); | ||
} | ||
}); | ||
it('includes top_k when it is within 1 - 40', async () => { | ||
const reqWithEmptyConfigs: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: {top_k: 1}, | ||
safety_settings: [], | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
reqWithEmptyConfigs, | ||
TEST_API_ENDPOINT | ||
); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).toContain('top_k'); | ||
} | ||
}); | ||
it('aggregates citation metadata', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect( | ||
resp.response.candidates[0].citationMetadata?.citationSources.length | ||
).toEqual( | ||
TEST_MODEL_RESPONSE.candidates[0].citationMetadata.citationSources.length | ||
); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.resolveTo(expectedResult); | ||
const resp = await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 1', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{text: 'What is the weather like in Boston?'}], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when functionResponse is not immediately following functionCall case2', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{text: 'What is the weather like in Boston?'}], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await generateContent( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('generateContentStream', () => { | ||
let expectedStreamResult: StreamGenerateContentResult; | ||
let fetchSpy: jasmine.Spy; | ||
beforeEach(() => { | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = new Response( | ||
JSON.stringify(expectedStreamResult), | ||
fetchResponseObj | ||
); | ||
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed text content', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await generateContentStream( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await generateContentStream( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
TEST_CHAT_MESSAGE_TEXT, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with a GCS URI', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_MULTIPART_MESSAGE, | ||
}; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await generateContentStream( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with base64 data', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_MULTIPART_MESSAGE_BASE64, | ||
}; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await generateContentStream( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedStreamResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedStreamResult); | ||
const resp = await generateContentStream( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
); | ||
expect(resp).toEqual(expectedStreamResult); | ||
}); | ||
it('throws ClientError when functionResponse is not immediately following functionCall case1', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{text: 'What is the weather like in Boston?'}], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await generateContentStream( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
).catch(e => { | ||
@@ -141,2 +762,30 @@ expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
it('throws ClientError when functionResponse is not immediately following functionCall case2', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{text: 'What is the weather like in Boston?'}], | ||
}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await generateContentStream( | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_TOKEN_PROMISE, | ||
req, | ||
TEST_API_ENDPOINT | ||
).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); |
src/index.ts
@@ -18,661 +18,3 @@ /** | ||
/* tslint:disable */ | ||
import {GoogleAuth, GoogleAuthOptions} from 'google-auth-library'; | ||
import { | ||
processCountTokenResponse, | ||
processNonStream, | ||
processStream, | ||
} from './process_stream'; | ||
import { | ||
Content, | ||
CountTokensRequest, | ||
CountTokensResponse, | ||
GenerateContentRequest, | ||
GenerateContentResult, | ||
GenerationConfig, | ||
ModelParams, | ||
Part, | ||
SafetySetting, | ||
StreamGenerateContentResult, | ||
Tool, | ||
VertexInit, | ||
} from './types/content'; | ||
import { | ||
ClientError, | ||
GoogleAuthError, | ||
GoogleGenerativeAIError, | ||
} from './types/errors'; | ||
import {constants, postRequest} from './util'; | ||
export {VertexAI} from './vertex_ai'; | ||
export * from './types'; | ||
/** | ||
* Base class for authenticating to Vertex, creates the preview namespace. | ||
*/ | ||
export class VertexAI { | ||
public preview: VertexAI_Preview; | ||
/** | ||
* @constructor | ||
* @param {VertexInit} init - assign authentication related information, | ||
* including project and location string, to instantiate a Vertex AI | ||
* client. | ||
*/ | ||
constructor(init: VertexInit) { | ||
/** | ||
* preview property is used to access any SDK methods available in public | ||
* preview, currently all functionality. | ||
*/ | ||
this.preview = new VertexAI_Preview( | ||
init.project, | ||
init.location, | ||
init.apiEndpoint, | ||
init.googleAuthOptions | ||
); | ||
} | ||
} | ||
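/** | ||
 * Illustrative sketch (added commentary, not part of this diff): constructing | ||
 * the client and requesting a model handle. The project and location strings | ||
 * are placeholders. | ||
 * @example | ||
 * const vertexAI = new VertexAI({project: 'my-project', location: 'us-central1'}); | ||
 * const model = vertexAI.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
 */ | ||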
/** | ||
* VertexAI class internal implementation for authentication. | ||
*/ | ||
export class VertexAI_Preview { | ||
protected googleAuth: GoogleAuth; | ||
/** | ||
* @constructor | ||
* @param {string} - project The Google Cloud project to use for the request | ||
* @param {string} - location The Google Cloud project location to use for the request | ||
* @param {string} - [apiEndpoint] The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @param {GoogleAuthOptions} - [googleAuthOptions] The Authentication options provided by google-auth-library. | ||
* Complete list of authentication options is documented in the GoogleAuthOptions interface: | ||
* https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
*/ | ||
constructor( | ||
readonly project: string, | ||
readonly location: string, | ||
readonly apiEndpoint?: string, | ||
readonly googleAuthOptions?: GoogleAuthOptions | ||
) { | ||
const opts = this.validateGoogleAuthOptions(project, googleAuthOptions); | ||
this.project = project; | ||
this.location = location; | ||
this.apiEndpoint = apiEndpoint; | ||
this.googleAuth = new GoogleAuth(opts); | ||
} | ||
/** | ||
* @param {ModelParams} modelParams - {@link ModelParams} Parameters to specify the generative model. | ||
* @return {GenerativeModel} Instance of the GenerativeModel class. {@link GenerativeModel} | ||
*/ | ||
getGenerativeModel(modelParams: ModelParams): GenerativeModel { | ||
const getGenerativeModelParams: GetGenerativeModelParams = { | ||
model: modelParams.model, | ||
project: this.project, | ||
location: this.location, | ||
googleAuth: this.googleAuth, | ||
apiEndpoint: this.apiEndpoint, | ||
safety_settings: modelParams.safety_settings, | ||
tools: modelParams.tools, | ||
}; | ||
if (modelParams.generation_config) { | ||
getGenerativeModelParams.generation_config = validateGenerationConfig( | ||
modelParams.generation_config | ||
); | ||
} | ||
return new GenerativeModel(getGenerativeModelParams); | ||
} | ||
validateGoogleAuthOptions( | ||
project: string, | ||
googleAuthOptions?: GoogleAuthOptions | ||
): GoogleAuthOptions { | ||
let opts: GoogleAuthOptions; | ||
const requiredScope = 'https://www.googleapis.com/auth/cloud-platform'; | ||
if (!googleAuthOptions) { | ||
opts = { | ||
scopes: requiredScope, | ||
}; | ||
return opts; | ||
} | ||
if ( | ||
googleAuthOptions.projectId && | ||
googleAuthOptions.projectId !== project | ||
) { | ||
throw new Error( | ||
`inconsistent project ID values. argument project got value ${project} but googleAuthOptions.projectId got value ${googleAuthOptions.projectId}` | ||
); | ||
} | ||
opts = googleAuthOptions; | ||
if (!opts.scopes) { | ||
opts.scopes = requiredScope; | ||
return opts; | ||
} | ||
if ( | ||
(typeof opts.scopes === 'string' && opts.scopes !== requiredScope) || | ||
(Array.isArray(opts.scopes) && opts.scopes.indexOf(requiredScope) < 0) | ||
) { | ||
throw new GoogleAuthError( | ||
`input GoogleAuthOptions.scopes ${opts.scopes} doesn't contain required scope ${requiredScope}, please include ${requiredScope} into GoogleAuthOptions.scopes or leave GoogleAuthOptions.scopes undefined` | ||
); | ||
} | ||
return opts; | ||
} | ||
} | ||
/** | ||
* Params to initiate a multiturn chat with the model via startChat | ||
* @property {Content[]} - [history] history of the chat session. {@link Content} | ||
* @property {SafetySetting[]} - [safety_settings] Array of {@link SafetySetting} | ||
* @property {GenerationConfig} - [generation_config] {@link GenerationConfig} | ||
*/ | ||
export declare interface StartChatParams { | ||
history?: Content[]; | ||
safety_settings?: SafetySetting[]; | ||
generation_config?: GenerationConfig; | ||
tools?: Tool[]; | ||
} | ||
// StartChatSessionRequest and ChatSession are defined here instead of in | ||
// src/types to avoid a circular dependency issue due to the dep on | ||
// VertexAI_Preview | ||
/** | ||
* All params passed to initiate multiturn chat via startChat | ||
* @property {VertexAI_Preview} - _vertex_instance {@link VertexAI_Preview} | ||
* @property {GenerativeModel} - _model_instance {@link GenerativeModel} | ||
*/ | ||
export declare interface StartChatSessionRequest extends StartChatParams { | ||
project: string; | ||
location: string; | ||
_model_instance: GenerativeModel; | ||
} | ||
/** | ||
* @property {string} model - model name | ||
* @property {string} project - The Google Cloud project to use for the request | ||
* @property {string} location - The Google Cloud project location to use for the request | ||
* @property {GoogleAuth} googleAuth - GoogleAuth class instance that handles authentication. | ||
* Details about GoogleAuth can be found at https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
* @property {string} - [apiEndpoint] The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @property {GenerationConfig} [generation_config] - {@link | ||
* GenerationConfig} | ||
* @property {SafetySetting[]} [safety_settings] - {@link SafetySetting} | ||
* @property {Tool[]} [tools] - {@link Tool} | ||
*/ | ||
export declare interface GetGenerativeModelParams extends ModelParams { | ||
model: string; | ||
project: string; | ||
location: string; | ||
googleAuth: GoogleAuth; | ||
apiEndpoint?: string; | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
} | ||
/** | ||
* Chat session to make multi-turn send message request. | ||
* `sendMessage` method makes async call to get response of a chat message. | ||
* `sendMessageStream` method makes async call to stream response of a chat message. | ||
*/ | ||
export class ChatSession { | ||
private project: string; | ||
private location: string; | ||
private historyInternal: Content[]; | ||
private _model_instance: GenerativeModel; | ||
private _send_stream_promise: Promise<void> = Promise.resolve(); | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
get history(): Content[] { | ||
return this.historyInternal; | ||
} | ||
/** | ||
* @constructor | ||
* @param {StartChatSessionRequest} request - {@link StartChatSessionRequest} | ||
*/ | ||
constructor(request: StartChatSessionRequest) { | ||
this.project = request.project; | ||
this.location = request.location; | ||
this._model_instance = request._model_instance; | ||
this.historyInternal = request.history ?? []; | ||
this.generation_config = request.generation_config; | ||
this.safety_settings = request.safety_settings; | ||
this.tools = request.tools; | ||
} | ||
/** | ||
* Make an async call to send message. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<GenerateContentResult>} Promise of {@link GenerateContentResult} | ||
*/ | ||
async sendMessage( | ||
request: string | Array<string | Part> | ||
): Promise<GenerateContentResult> { | ||
const newContent: Content[] = | ||
formulateNewContentFromSendMessageRequest(request); | ||
const generateContentRequest: GenerateContentRequest = { | ||
contents: this.historyInternal.concat(newContent), | ||
safety_settings: this.safety_settings, | ||
generation_config: this.generation_config, | ||
tools: this.tools, | ||
}; | ||
const generateContentResult: GenerateContentResult = | ||
await this._model_instance | ||
.generateContent(generateContentRequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
const generateContentResponse = await generateContentResult.response; | ||
// Only push the latest message to history if the response returned a result | ||
if (generateContentResponse.candidates.length !== 0) { | ||
this.historyInternal = this.historyInternal.concat(newContent); | ||
const contentFromAssistant = | ||
generateContentResponse.candidates[0].content; | ||
if (!contentFromAssistant.role) { | ||
contentFromAssistant.role = constants.MODEL_ROLE; | ||
} | ||
this.historyInternal.push(contentFromAssistant); | ||
} else { | ||
// TODO: handle promptFeedback in the response | ||
throw new Error('Did not get a candidate from the model'); | ||
} | ||
return Promise.resolve({response: generateContentResponse}); | ||
} | ||
async appendHistory( | ||
streamGenerateContentResultPromise: Promise<StreamGenerateContentResult>, | ||
newContent: Content[] | ||
): Promise<void> { | ||
const streamGenerateContentResult = | ||
await streamGenerateContentResultPromise; | ||
const streamGenerateContentResponse = | ||
await streamGenerateContentResult.response; | ||
// Only push the latest message to history if the response returned a result | ||
if (streamGenerateContentResponse.candidates.length !== 0) { | ||
this.historyInternal = this.historyInternal.concat(newContent); | ||
const contentFromAssistant = | ||
streamGenerateContentResponse.candidates[0].content; | ||
if (!contentFromAssistant.role) { | ||
contentFromAssistant.role = constants.MODEL_ROLE; | ||
} | ||
this.historyInternal.push(contentFromAssistant); | ||
} else { | ||
// TODO: handle promptFeedback in the response | ||
throw new Error('Did not get a candidate from the model'); | ||
} | ||
} | ||
/** | ||
* Make an async call to stream send message. Response will be returned in stream. | ||
* @param {string | Array<string | Part>} request - send message request. {@link Part} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
async sendMessageStream( | ||
request: string | Array<string | Part> | ||
): Promise<StreamGenerateContentResult> { | ||
const newContent: Content[] = | ||
formulateNewContentFromSendMessageRequest(request); | ||
const generateContentRequest: GenerateContentRequest = { | ||
contents: this.historyInternal.concat(newContent), | ||
safety_settings: this.safety_settings, | ||
generation_config: this.generation_config, | ||
tools: this.tools, | ||
}; | ||
const streamGenerateContentResultPromise = this._model_instance | ||
.generateContentStream(generateContentRequest) | ||
.catch(e => { | ||
throw e; | ||
}); | ||
this._send_stream_promise = this.appendHistory( | ||
streamGenerateContentResultPromise, | ||
newContent | ||
).catch(e => { | ||
throw new GoogleGenerativeAIError('exception appending chat history', e); | ||
}); | ||
return streamGenerateContentResultPromise; | ||
} | ||
} | ||
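/** | ||
 * Sketch of multi-turn usage (added commentary; `model` is a placeholder | ||
 * handle obtained via getGenerativeModel): | ||
 * @example | ||
 * const chat = model.startChat(); // no remote call is made yet | ||
 * const result = await chat.sendMessage('How can I learn more about Node.js?'); | ||
 * console.log(result.response.candidates[0].content); | ||
 * console.log(chat.history.length); // 2: the user turn plus the model turn | ||
 */ | ||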
/** | ||
* Base class for generative models. | ||
* NOTE: this class should not be instantiated directly. Use | ||
* `vertexai.preview.getGenerativeModel()` instead. | ||
*/ | ||
export class GenerativeModel { | ||
model: string; | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
private project: string; | ||
private location: string; | ||
private googleAuth: GoogleAuth; | ||
private publisherModelEndpoint: string; | ||
private apiEndpoint?: string; | ||
/** | ||
* @constructor | ||
* @param {GetGenerativeModelParams} getGenerativeModelParams - {@link GetGenerativeModelParams} | ||
*/ | ||
constructor(getGenerativeModelParams: GetGenerativeModelParams) { | ||
this.project = getGenerativeModelParams.project; | ||
this.location = getGenerativeModelParams.location; | ||
this.apiEndpoint = getGenerativeModelParams.apiEndpoint; | ||
this.googleAuth = getGenerativeModelParams.googleAuth; | ||
this.model = getGenerativeModelParams.model; | ||
this.generation_config = getGenerativeModelParams.generation_config; | ||
this.safety_settings = getGenerativeModelParams.safety_settings; | ||
this.tools = getGenerativeModelParams.tools; | ||
if (this.model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
} | ||
/** | ||
* Get access token from GoogleAuth. Throws GoogleAuthError when fails. | ||
* @return {Promise<any>} Promise of token | ||
*/ | ||
get token(): Promise<any> { | ||
const credential_error_message = | ||
'\nUnable to authenticate your request\ | ||
\nDepending on your run time environment, you can get authentication by\ | ||
\n- if in local instance or cloud shell: `!gcloud auth login`\ | ||
\n- if in Colab:\ | ||
\n -`from google.colab import auth`\ | ||
\n -`auth.authenticate_user()`\ | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication'; | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new GoogleAuthError(credential_error_message, e); | ||
}); | ||
return tokenPromise; | ||
} | ||
/** | ||
* Make an async call to generate content. | ||
* @param request A GenerateContentRequest object with the request contents. | ||
* @return The GenerateContentResponse object with the response candidates. | ||
*/ | ||
async generateContent( | ||
request: GenerateContentRequest | string | ||
): Promise<GenerateContentResult> { | ||
request = formatContentRequest( | ||
request, | ||
this.generation_config, | ||
this.safety_settings | ||
); | ||
validateGenerateContentRequest(request); | ||
if (request.generation_config) { | ||
request.generation_config = validateGenerationConfig( | ||
request.generation_config | ||
); | ||
} | ||
const generateContentRequest: GenerateContentRequest = { | ||
contents: request.contents, | ||
generation_config: request.generation_config ?? this.generation_config, | ||
safety_settings: request.safety_settings ?? this.safety_settings, | ||
tools: request.tools ?? [], | ||
}; | ||
const response: Response | undefined = await postRequest({ | ||
region: this.location, | ||
project: this.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: constants.GENERATE_CONTENT_METHOD, | ||
token: await this.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this.apiEndpoint, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const result: GenerateContentResult = processNonStream(response); | ||
return Promise.resolve(result); | ||
} | ||
/** | ||
* Make an async stream request to generate content. The response will be returned in stream. | ||
* @param {GenerateContentRequest} request - {@link GenerateContentRequest} | ||
* @return {Promise<StreamGenerateContentResult>} Promise of {@link StreamGenerateContentResult} | ||
*/ | ||
async generateContentStream( | ||
request: GenerateContentRequest | string | ||
): Promise<StreamGenerateContentResult> { | ||
request = formatContentRequest( | ||
request, | ||
this.generation_config, | ||
this.safety_settings | ||
); | ||
validateGenerateContentRequest(request); | ||
if (request.generation_config) { | ||
request.generation_config = validateGenerationConfig( | ||
request.generation_config | ||
); | ||
} | ||
const generateContentRequest: GenerateContentRequest = { | ||
contents: request.contents, | ||
generation_config: request.generation_config ?? this.generation_config, | ||
safety_settings: request.safety_settings ?? this.safety_settings, | ||
tools: request.tools ?? [], | ||
}; | ||
const response = await postRequest({ | ||
region: this.location, | ||
project: this.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await this.token, | ||
data: generateContentRequest, | ||
apiEndpoint: this.apiEndpoint, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
const streamResult = processStream(response); | ||
return Promise.resolve(streamResult); | ||
} | ||
/** | ||
* Make an async request to count tokens. | ||
* @param request A CountTokensRequest object with the request contents. | ||
* @return The CountTokensResponse object with the token count. | ||
*/ | ||
async countTokens(request: CountTokensRequest): Promise<CountTokensResponse> { | ||
const response = await postRequest({ | ||
region: this.location, | ||
project: this.project, | ||
resourcePath: this.publisherModelEndpoint, | ||
resourceMethod: 'countTokens', | ||
token: await this.token, | ||
data: request, | ||
apiEndpoint: this.apiEndpoint, | ||
}).catch(e => { | ||
throw new GoogleGenerativeAIError('exception posting request', e); | ||
}); | ||
throwErrorIfNotOK(response); | ||
return processCountTokenResponse(response); | ||
} | ||
/** | ||
* Instantiate a ChatSession. | ||
* This method doesn't make any call to remote endpoint. | ||
* Any call to remote endpoint is implemented in ChatSession class @see ChatSession | ||
* @param{StartChatParams} [request] - {@link StartChatParams} | ||
* @return {ChatSession} {@link ChatSession} | ||
*/ | ||
startChat(request?: StartChatParams): ChatSession { | ||
const startChatRequest: StartChatSessionRequest = { | ||
project: this.project, | ||
location: this.location, | ||
_model_instance: this, | ||
}; | ||
if (request) { | ||
startChatRequest.history = request.history; | ||
startChatRequest.generation_config = | ||
request.generation_config ?? this.generation_config; | ||
startChatRequest.safety_settings = | ||
request.safety_settings ?? this.safety_settings; | ||
startChatRequest.tools = request.tools ?? this.tools; | ||
} | ||
return new ChatSession(startChatRequest); | ||
} | ||
} | ||
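/** | ||
 * Sketch of the three request paths above (added commentary; `model` and the | ||
 * request bodies are placeholders): | ||
 * @example | ||
 * const single = await model.generateContent('Why is the sky blue?'); | ||
 * const streaming = await model.generateContentStream('Tell me a story'); | ||
 * for await (const item of streaming.stream) { ... } // partial candidates | ||
 * const aggregated = await streaming.response; // aggregated response | ||
 * const {totalTokens} = await model.countTokens({contents: [...]}); | ||
 */ | ||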
function formulateNewContentFromSendMessageRequest( | ||
request: string | Array<string | Part> | ||
): Content[] { | ||
let newParts: Part[] = []; | ||
if (typeof request === 'string') { | ||
newParts = [{text: request}]; | ||
} else if (Array.isArray(request)) { | ||
for (const item of request) { | ||
if (typeof item === 'string') { | ||
newParts.push({text: item}); | ||
} else { | ||
newParts.push(item); | ||
} | ||
} | ||
} | ||
return assignRoleToPartsAndValidateSendMessageRequest(newParts); | ||
} | ||
/** | ||
* When multiple Part types (i.e. FunctionResponsePart and TextPart) are | ||
* passed in a single Part array, we may need to assign different roles to each | ||
* part. Currently only FunctionResponsePart requires a role other than 'user'. | ||
* @ignore | ||
* @param {Array<Part>} parts Array of parts to pass to the model | ||
* @return {Content[]} Array of content items | ||
*/ | ||
function assignRoleToPartsAndValidateSendMessageRequest( | ||
parts: Array<Part> | ||
): Content[] { | ||
const userContent: Content = {role: constants.USER_ROLE, parts: []}; | ||
const functionContent: Content = {role: constants.FUNCTION_ROLE, parts: []}; | ||
let hasUserContent = false; | ||
let hasFunctionContent = false; | ||
for (const part of parts) { | ||
if ('functionResponse' in part) { | ||
functionContent.parts.push(part); | ||
hasFunctionContent = true; | ||
} else { | ||
userContent.parts.push(part); | ||
hasUserContent = true; | ||
} | ||
} | ||
if (hasUserContent && hasFunctionContent) { | ||
throw new ClientError( | ||
'Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.' | ||
); | ||
} | ||
if (!hasUserContent && !hasFunctionContent) { | ||
throw new ClientError('No content is provided for sending chat message.'); | ||
} | ||
if (hasUserContent) { | ||
return [userContent]; | ||
} | ||
return [functionContent]; | ||
} | ||
function throwErrorIfNotOK(response: Response | undefined) { | ||
if (response === undefined) { | ||
throw new GoogleGenerativeAIError('response is undefined'); | ||
} | ||
const status: number = response.status; | ||
const statusText: string = response.statusText; | ||
const errorMessage = `got status: ${status} ${statusText}`; | ||
if (status >= 400 && status < 500) { | ||
throw new ClientError(errorMessage); | ||
} | ||
if (!response.ok) { | ||
throw new GoogleGenerativeAIError(errorMessage); | ||
} | ||
} | ||
function validateGcsInput(contents: Content[]) { | ||
for (const content of contents) { | ||
for (const part of content.parts) { | ||
if ('file_data' in part) { | ||
// @ts-ignore | ||
const uri = part['file_data']['file_uri']; | ||
if (!uri.startsWith('gs://')) { | ||
throw new URIError( | ||
`Found invalid Google Cloud Storage URI ${uri}, Google Cloud Storage URIs must start with gs://` | ||
); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
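/** | ||
 * A file_data part that passes the check above (added commentary; bucket and | ||
 * object names are placeholders). Any file_uri not starting with 'gs://' | ||
 * raises URIError. | ||
 * @example | ||
 * {file_data: {file_uri: 'gs://test_bucket/test_image.jpeg', mime_type: 'image/jpeg'}} | ||
 */ | ||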
function validateFunctionResponseRequest(contents: Content[]) { | ||
const latestContentPart = contents[contents.length - 1].parts[0]; | ||
if (!('functionResponse' in latestContentPart)) { | ||
return; | ||
} | ||
const errorMessage = | ||
'Please ensure that function response turn comes immediately after a function call turn.'; | ||
if (contents.length < 2) { | ||
throw new ClientError(errorMessage); | ||
} | ||
const secondLatestContentPart = contents[contents.length - 2].parts[0]; | ||
if (!('functionCall' in secondLatestContentPart)) { | ||
throw new ClientError(errorMessage); | ||
} | ||
} | ||
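/** | ||
 * The turn ordering this enforces (added commentary): a 'function' turn is only | ||
 * legal when the immediately preceding turn carries the model's functionCall. | ||
 * Dropping the 'model' turn below makes the request throw ClientError. | ||
 * @example | ||
 * contents: [ | ||
 *   {role: 'user', parts: [{text: 'What is the weather in Boston?'}]}, | ||
 *   {role: 'model', parts: [{functionCall: {name: 'get_current_weather', args: {...}}}]}, | ||
 *   {role: 'function', parts: [{functionResponse: {...}}]}, // OK | ||
 * ] | ||
 */ | ||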
function validateGenerateContentRequest(request: GenerateContentRequest) { | ||
validateGcsInput(request.contents); | ||
validateFunctionResponseRequest(request.contents); | ||
} | ||
function validateGenerationConfig( | ||
generation_config: GenerationConfig | ||
): GenerationConfig { | ||
if ('top_k' in generation_config) { | ||
if (!(generation_config.top_k! > 0) || !(generation_config.top_k! <= 40)) { | ||
delete generation_config.top_k; | ||
} | ||
} | ||
return generation_config; | ||
} | ||
function formatContentRequest( | ||
request: GenerateContentRequest | string, | ||
generation_config?: GenerationConfig, | ||
safety_settings?: SafetySetting[] | ||
): GenerateContentRequest { | ||
if (typeof request === 'string') { | ||
return { | ||
contents: [{role: constants.USER_ROLE, parts: [{text: request}]}], | ||
generation_config: generation_config, | ||
safety_settings: safety_settings, | ||
}; | ||
} else { | ||
return request; | ||
} | ||
} |
@@ -19,3 +19,3 @@ /** | ||
// @ts-nocheck | ||
import {GoogleAuthOptions} from 'google-auth-library'; | ||
import {GoogleAuth, GoogleAuthOptions} from 'google-auth-library'; | ||
@@ -66,2 +66,27 @@ /** | ||
/** | ||
* @property {string} model - model name | ||
* @property {string} project - The Google Cloud project to use for the request | ||
* @property {string} location - The Google Cloud project location to use for the request | ||
* @property {GoogleAuth} googleAuth - GoogleAuth class instance that handles authentication. | ||
* Details about GoogleAuth can be found at https://github.com/googleapis/google-auth-library-nodejs/blob/main/src/auth/googleauth.ts | ||
* @property {string} - [apiEndpoint] The base Vertex AI endpoint to use for the request. If | ||
* not provided, the default regionalized endpoint | ||
* (i.e. us-central1-aiplatform.googleapis.com) will be used. | ||
* @property {GenerationConfig} [generation_config] - {@link | ||
* GenerationConfig} | ||
* @property {SafetySetting[]} [safety_settings] - {@link SafetySetting} | ||
* @property {Tool[]} [tools] - {@link Tool} | ||
*/ | ||
export declare interface GetGenerativeModelParams extends ModelParams { | ||
model: string; | ||
project: string; | ||
location: string; | ||
googleAuth: GoogleAuth; | ||
apiEndpoint?: string; | ||
generation_config?: GenerationConfig; | ||
safety_settings?: SafetySetting[]; | ||
tools?: Tool[]; | ||
} | ||
/** | ||
* Configuration for initializing a model, for example via getGenerativeModel | ||
@@ -467,3 +492,4 @@ * @property {string} model - model name. | ||
* @property {Content} - content. {@link Content} | ||
* @property {number} - [index]. The index of the candidate in the {@link GenerateContentResponse} | ||
* @property {number} - [index]. The index of the candidate in the {@link | ||
* GenerateContentResponse} | ||
* @property {FinishReason} - [finishReason]. {@link FinishReason} | ||
@@ -473,2 +499,4 @@ * @property {string} - [finishMessage]. | ||
* @property {CitationMetadata} - [citationMetadata]. {@link CitationMetadata} | ||
* @property {GroundingMetadata} - [groundingMetadata]. {@link | ||
* GroundingMetadata} | ||
*/ | ||
@@ -482,2 +510,3 @@ export declare interface GenerateContentCandidate { | ||
citationMetadata?: CitationMetadata; | ||
groundingMetadata?: GroundingMetadata; | ||
functionCall?: FunctionCall; | ||
@@ -509,2 +538,53 @@ } | ||
/** | ||
* A collection of grounding attributions for a piece of content. | ||
* @property {string[]} - [webSearchQueries]. Web search queries for the | ||
* follow-up web search. | ||
* @property {GroundingAttribution[]} - [groundingAttributions]. Array of {@link | ||
* GroundingAttribution} | ||
*/ | ||
export declare interface GroundingMetadata { | ||
webSearchQueries?: string[]; | ||
groundingAttributions?: GroundingAttribution[]; | ||
} | ||
/** | ||
* Grounding attribution. | ||
* @property {GroundingAttributionWeb} - [web] Attribution from the web. | ||
* @property {GroundingAttributionSegment} - [segment] Segment of the content | ||
* this attribution belongs to. | ||
* @property {number} - [confidenceScore] Confidence score of the attribution. | ||
* Ranges from 0 to 1. 1 is the most confident. | ||
*/ | ||
export declare interface GroundingAttribution { | ||
web?: GroundingAttributionWeb; | ||
segment?: GroundingAttributionSegment; | ||
confidenceScore?: number; | ||
} | ||
/** | ||
* Segment of the content this attribution belongs to. | ||
* @property {number} - [partIndex] The index of a Part object within its | ||
* parent Content object. | ||
* @property {number} - [startIndex] Start index in the given Part, measured in | ||
* bytes. Offset from the start of the Part, inclusive, starting at zero. | ||
* @property {number} - [endIndex] End index in the given Part, measured in | ||
* bytes. Offset from the start of the Part, exclusive, starting at zero. | ||
*/ | ||
export declare interface GroundingAttributionSegment { | ||
partIndex?: number; | ||
startIndex?: number; | ||
endIndex?: number; | ||
} | ||
/** | ||
* Attribution from the web. | ||
* @property {string} - [uri] URI reference of the attribution. | ||
* @property {string} - [title] Title of the attribution. | ||
*/ | ||
export declare interface GroundingAttributionWeb { | ||
uri?: string; | ||
title?: string; | ||
} | ||
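/** | ||
 * Reading grounding metadata off a response (added commentary; assumes a | ||
 * grounding-enabled request and a placeholder `result`): | ||
 * @example | ||
 * const groundingMetadata = result.response.candidates[0].groundingMetadata; | ||
 * if (groundingMetadata) { | ||
 *   console.log(groundingMetadata.webSearchQueries); | ||
 *   console.log(groundingMetadata.groundingAttributions?.[0]?.web?.uri); | ||
 * } | ||
 */ | ||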
/** | ||
* A predicted FunctionCall returned from the model that contains a string | ||
@@ -570,5 +650,5 @@ * representing the FunctionDeclaration.name with the parameters and their | ||
/** | ||
* A Tool is a piece of code that enables the system to interact with | ||
* external systems to perform an action, or set of actions, outside of | ||
* knowledge and scope of the model. | ||
* A FunctionDeclarationsTool is a piece of code that enables the system to | ||
* interact with external systems to perform an action, or set of actions, | ||
* outside of knowledge and scope of the model. | ||
* @property {object} - function_declarations One or more function declarations | ||
@@ -583,7 +663,53 @@ * to be passed to the model along with the current user query. Model may decide | ||
*/ | ||
export declare interface Tool { | ||
function_declarations: FunctionDeclaration[]; | ||
export declare interface FunctionDeclarationsTool { | ||
function_declarations?: FunctionDeclaration[]; | ||
} | ||
export declare interface RetrievalTool { | ||
retrieval?: Retrieval; | ||
} | ||
export declare interface GoogleSearchRetrievalTool { | ||
googleSearchRetrieval?: GoogleSearchRetrieval; | ||
} | ||
export declare type Tool = | ||
| FunctionDeclarationsTool | ||
| RetrievalTool | ||
| GoogleSearchRetrievalTool; | ||
/** | ||
* Defines a retrieval tool that model can call to access external knowledge. | ||
* @property {VertexAISearch} - [vertexAiSearch] Set to use data source powered | ||
* by Vertex AI Search. | ||
* @property {boolean} - [disableAttribution] Disable using the result from | ||
* this tool in detecting grounding attribution. This does not affect how the | ||
* result is given to the model for generation. | ||
*/ | ||
export declare interface Retrieval { | ||
vertexAiSearch?: VertexAISearch; | ||
disableAttribution?: boolean; | ||
} | ||
/** | ||
* Tool to retrieve public web data for grounding, powered by Google. | ||
* @property {boolean} - [disableAttribution] Disable using the result from this | ||
* tool in detecting grounding attribution. This does not affect how the result | ||
* is given to the model for generation. | ||
*/ | ||
export declare interface GoogleSearchRetrieval { | ||
disableAttribution?: boolean; | ||
} | ||
/** | ||
* Retrieve from Vertex AI Search datastore for grounding. See | ||
* https://cloud.google.com/vertex-ai-search-and-conversation | ||
* @property {string} - [datastore] Fully-qualified Vertex AI Search's datastore | ||
* resource ID. projects/<>/locations/<>/collections/<>/dataStores/<> | ||
*/ | ||
export declare interface VertexAISearch { | ||
datastore: string; | ||
} | ||
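/** | ||
 * Each Tool variant as a literal (added commentary; the declarations array and | ||
 * datastore path are placeholders): | ||
 * @example | ||
 * const functionTool: FunctionDeclarationsTool = {function_declarations: [...]}; | ||
 * const searchTool: GoogleSearchRetrievalTool = { | ||
 *   googleSearchRetrieval: {disableAttribution: false}, | ||
 * }; | ||
 * const retrievalTool: RetrievalTool = { | ||
 *   retrieval: {vertexAiSearch: {datastore: 'projects/<>/locations/<>/collections/<>/dataStores/<>'}}, | ||
 * }; | ||
 */ | ||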
/** | ||
* Contains the list of OpenAPI data types | ||
@@ -630,1 +756,27 @@ * as defined by https://swagger.io/docs/specification/data-models/data-types/ | ||
} | ||
/** | ||
* Params to initiate a multiturn chat with the model via startChat | ||
* @property {Content[]} - [history] history of the chat session. {@link Content} | ||
* @property {SafetySetting[]} - [safety_settings] Array of {@link SafetySetting} | ||
* @property {GenerationConfig} - [generation_config] {@link GenerationConfig} | ||
*/ | ||
export declare interface StartChatParams { | ||
history?: Content[]; | ||
safety_settings?: SafetySetting[]; | ||
generation_config?: GenerationConfig; | ||
tools?: Tool[]; | ||
api_endpoint?: string; | ||
} | ||
/** | ||
* All params passed to initiate multiturn chat via startChat | ||
* @property {string} project - project The Google Cloud project to use for the request | ||
* @property {string} location - The Google Cloud project location to use for the request | ||
*/ | ||
export declare interface StartChatSessionRequest extends StartChatParams { | ||
project: string; | ||
location: string; | ||
googleAuth: GoogleAuth; | ||
publisher_model_endpoint: string; | ||
} |
@@ -23,4 +23,12 @@ /** | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_VERSION = '0.3.1'; // x-release-please-version | ||
const CLIENT_LIBRARY_VERSION = '0.4.0'; // x-release-please-version | ||
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`; | ||
export const USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`; | ||
export const CREDENTIAL_ERROR_MESSAGE = | ||
'\nUnable to authenticate your request\ | ||
\nDepending on your run time environment, you can get authentication by\ | ||
\n- if in local instance or cloud shell: `!gcloud auth login`\ | ||
\n- if in Colab:\ | ||
\n -`from google.colab import auth`\ | ||
\n -`auth.authenticate_user()`\ | ||
\n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication'; |
@@ -19,2 +19,1 @@ /** | ||
export * as constants from './constants'; | ||
export {postRequest} from './post_request'; |
@@ -19,3 +19,9 @@ /** | ||
// @ts-ignore | ||
import {ClientError, TextPart, VertexAI} from '../src'; | ||
import { | ||
ClientError, | ||
FunctionDeclarationsTool, | ||
GoogleSearchRetrievalTool, | ||
TextPart, | ||
VertexAI, | ||
} from '../src'; | ||
import {FunctionDeclarationSchemaType} from '../src/types'; | ||
@@ -57,3 +63,3 @@ | ||
const TOOLS_WITH_FUNCTION_DECLARATION = [ | ||
const TOOLS_WITH_FUNCTION_DECLARATION: FunctionDeclarationsTool[] = [ | ||
{ | ||
@@ -80,2 +86,10 @@ function_declarations: [ | ||
const TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL: GoogleSearchRetrievalTool[] = [ | ||
{ | ||
googleSearchRetrieval: { | ||
disableAttribution: false, | ||
}, | ||
}, | ||
]; | ||
const WEATHER_FORECAST = 'super nice'; | ||
@@ -104,3 +118,3 @@ const FUNCTION_RESPONSE_PART = [ | ||
const generativeTextModel = vertex_ai.preview.getGenerativeModel({ | ||
const generativeTextModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
@@ -111,3 +125,9 @@ generation_config: { | ||
}); | ||
const generativeTextModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
const generativeTextModelPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const generativeTextModelWithPrefix = vertex_ai.getGenerativeModel({ | ||
model: 'models/gemini-pro', | ||
@@ -118,13 +138,28 @@ generation_config: { | ||
}); | ||
const textModelNoOutputLimit = vertex_ai.preview.getGenerativeModel({ | ||
const generativeTextModelWithPrefixPreview = | ||
vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro', | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const textModelNoOutputLimit = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeVisionModel = vertex_ai.preview.getGenerativeModel({ | ||
const textModelNoOutputLimitPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeVisionModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro-vision', | ||
}); | ||
const generativeVisionModelWithPrefix = vertex_ai.preview.getGenerativeModel({ | ||
const generativeVisionModelPreview = vertex_ai.preview.getGenerativeModel({ | ||
model: 'gemini-pro-vision', | ||
}); | ||
const generativeVisionModelWithPrefix = vertex_ai.getGenerativeModel({ | ||
model: 'models/gemini-pro-vision', | ||
}); | ||
const generativeVisionModelWithPrefixPreview = | ||
vertex_ai.preview.getGenerativeModel({ | ||
model: 'models/gemini-pro-vision', | ||
}); | ||
describe('generateContentStream', () => { | ||
@@ -150,2 +185,18 @@ beforeEach(() => { | ||
}); | ||
it('in preview should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = | ||
await generativeTextModelPreview.generateContentStream(TEXT_REQUEST); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview, for item ${item}` | ||
); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should not return invalid unicode', async () => { | ||
@@ -175,2 +226,28 @@ const streamingResp = await generativeTextModel.generateContentStream({ | ||
}); | ||
it('in preview should not return invalid unicode', async () => { | ||
const streamingResp = | ||
await generativeTextModelPreview.generateContentStream({ | ||
contents: [{role: 'user', parts: [{text: '创作一首古诗'}]}], | ||
}); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview, for item ${item}` | ||
); | ||
for (const candidate of item.candidates) { | ||
for (const part of candidate.content.parts as TextPart[]) { | ||
expect(part.text).not.toContain( | ||
'\ufffd', | ||
`sys test failure on generateContentStream in preview, for item ${item}` | ||
); | ||
} | ||
} | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content', async () => { | ||
@@ -192,2 +269,20 @@ const streamingResp = await generativeVisionModel.generateContentStream( | ||
}); | ||
it('in preview should return a stream and aggregated response when passed multipart base64 content', async () => { | ||
const streamingResp = | ||
await generativeVisionModelPreview.generateContentStream( | ||
MULTI_PART_BASE64_REQUEST | ||
); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview, for item ${item}` | ||
); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should throw ClientError when having invalid input', async () => { | ||
@@ -207,3 +302,3 @@ const badRequest = { | ||
expect(e).toBeInstanceOf(ClientError); | ||
expect(e.message).toBe( | ||
expect(e.message).toContain( | ||
'[VertexAI.ClientError]: got status: 400 Bad Request', | ||
@@ -215,2 +310,25 @@ `sys test failure on generateContentStream when having bad request | ||
}); | ||
it('in preview should throw ClientError when having invalid input', async () => { | ||
const badRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{text: 'describe this image:'}, | ||
{inline_data: {mime_type: 'image/png', data: 'invalid data'}}, | ||
], | ||
}, | ||
], | ||
}; | ||
await generativeVisionModelPreview | ||
.generateContentStream(badRequest) | ||
.catch(e => { | ||
expect(e).toBeInstanceOf(ClientError); | ||
expect(e.message).toContain( | ||
'[VertexAI.ClientError]: got status: 400 Bad Request', | ||
`sys test failure on generateContentStream in preview when having bad request | ||
got wrong error message: ${e.message}` | ||
); | ||
}); | ||
}); | ||
@@ -233,2 +351,20 @@ it('should return a stream and aggregated response when passed multipart GCS content', async () => { | ||
}); | ||
it('in preview should return a stream and aggregated response when passed multipart GCS content', async () => { | ||
const streamingResp = | ||
await generativeVisionModelPreview.generateContentStream( | ||
MULTI_PART_GCS_REQUEST | ||
); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview, for item ${item}` | ||
); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
@@ -254,2 +390,22 @@ const request = { | ||
}); | ||
it('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const request = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]}, | ||
{role: 'model', parts: FUNCTION_CALL}, | ||
{role: 'function', parts: FUNCTION_RESPONSE_PART}, | ||
], | ||
tools: TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const streamingResp = | ||
await generativeTextModelPreview.generateContentStream(request); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview, for item ${item}` | ||
); | ||
expect(item.candidates[0].content.parts[0].text?.toLowerCase()).toContain( | ||
WEATHER_FORECAST | ||
); | ||
} | ||
}); | ||
}); | ||
@@ -264,3 +420,3 @@ | ||
const aggregatedResp = await response.response; | ||
const aggregatedResp = response.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
@@ -270,2 +426,28 @@ `sys test failure on generateContentStream for aggregated response: ${aggregatedResp}` | ||
}); | ||
it('in preview should return the aggregated response', async () => { | ||
const response = | ||
await generativeTextModelPreview.generateContent(TEXT_REQUEST); | ||
const aggregatedResp = response.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
xit('should return grounding metadata when passed GoogleSearchRetrieval or Retrieval', async () => { | ||
const generativeTextModel = vertex_ai.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
//tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, | ||
}); | ||
const result = await generativeTextModel.generateContent({ | ||
contents: [{role: 'user', parts: [{text: 'Why is the sky blue?'}]}], | ||
tools: TOOLS_WITH_GOOGLE_SEARCH_RETRIEVAL, | ||
}); | ||
const response = result.response; | ||
const groundingMetadata = response.candidates[0].groundingMetadata; | ||
expect(groundingMetadata).toBeDefined(); | ||
if (groundingMetadata) { | ||
// expect(groundingMetadata.groundingAttributions).toBeTruthy(); | ||
expect(groundingMetadata.webSearchQueries).toBeTruthy(); | ||
} | ||
}); | ||
}); | ||
@@ -281,3 +463,3 @@ | ||
const result1 = await chat.sendMessage(chatInput1); | ||
const response1 = await result1.response; | ||
const response1 = result1.response; | ||
expect(response1.candidates[0]).toBeTruthy( | ||
@@ -288,2 +470,12 @@ `sys test failure on sendMessage for aggregated response: ${response1}` | ||
}); | ||
it('in preview should populate history and return a chat response', async () => { | ||
const chat = generativeTextModelPreview.startChat(); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessage(chatInput1); | ||
const response1 = result1.response; | ||
expect(response1.candidates[0]).toBeTruthy( | ||
`sys test failure on sendMessage in preview for aggregated response: ${response1}` | ||
); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
}); | ||
@@ -314,2 +506,22 @@ | ||
}); | ||
it('in preview should return a stream and populate history when generation_config is passed to startChat', async () => { | ||
const chat = generativeTextModelPreview.startChat({ | ||
generation_config: { | ||
max_output_tokens: 256, | ||
}, | ||
}); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on sendMessageStream in preview, for item ${item}` | ||
); | ||
} | ||
const resp = await result1.response; | ||
expect(resp.candidates[0]).toBeTruthy( | ||
`sys test failure on sendMessageStream in preview for aggregated response: ${resp}` | ||
); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return a stream and populate history when startChat is passed no request obj', async () => { | ||
@@ -330,2 +542,18 @@ const chat = generativeTextModel.startChat(); | ||
}); | ||
it('in preview should return a stream and populate history when startChat is passed no request obj', async () => { | ||
const chat = generativeTextModelPreview.startChat(); | ||
const chatInput1 = 'How can I learn more about Node.js?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on sendMessageStream in preview, for item ${item}` | ||
); | ||
} | ||
const resp = await result1.response; | ||
expect(resp.candidates[0]).toBeTruthy( | ||
`sys test failure on sendMessageStream in preview for aggregated response: ${resp}` | ||
); | ||
expect(chat.history.length).toBe(2); | ||
}); | ||
it('should return chunks as they come in', async () => { | ||
@@ -351,2 +579,23 @@ const chat = textModelNoOutputLimit.startChat({}); | ||
}); | ||
it('in preview should return chunks as they come in', async () => { | ||
const chat = textModelNoOutputLimitPreview.startChat({}); | ||
const chatInput1 = 'Tell me a story in 3000 words'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
let firstChunkTimestamp = 0; | ||
let aggregatedResultTimestamp = 0; | ||
const firstChunkFinalResultTimeDiff = 200; // ms | ||
for await (const item of result1.stream) { | ||
if (firstChunkTimestamp === 0) { | ||
firstChunkTimestamp = Date.now(); | ||
} | ||
} | ||
await result1.response; | ||
aggregatedResultTimestamp = Date.now(); | ||
expect(aggregatedResultTimestamp - firstChunkTimestamp).toBeGreaterThan( | ||
firstChunkFinalResultTimeDiff | ||
); | ||
}); | ||
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
@@ -383,2 +632,33 @@ const chat = generativeTextModel.startChat({ | ||
}); | ||
it('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const chat = generativeTextModelPreview.startChat({ | ||
tools: TOOLS_WITH_FUNCTION_DECLARATION, | ||
}); | ||
const chatInput1 = 'What is the weather in Boston?'; | ||
const result1 = await chat.sendMessageStream(chatInput1); | ||
for await (const item of result1.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on sendMessageStream in preview with function calling, for item ${item}` | ||
); | ||
} | ||
const response1 = await result1.response; | ||
expect( | ||
JSON.stringify(response1.candidates[0].content.parts[0].functionCall) | ||
).toContain(FUNCTION_CALL_NAME); | ||
expect( | ||
JSON.stringify(response1.candidates[0].content.parts[0].functionCall) | ||
).toContain('location'); | ||
// Send a follow up message with a FunctionResponse | ||
const result2 = await chat.sendMessageStream(FUNCTION_RESPONSE_PART); | ||
for await (const item of result2.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on sendMessageStream in preview with function calling, for item ${item}` | ||
); | ||
} | ||
const response2 = await result2.response; | ||
expect( | ||
JSON.stringify(response2.candidates[0].content.parts[0].text) | ||
).toContain(WEATHER_FORECAST); | ||
}); | ||
}); | ||
@@ -393,2 +673,9 @@ | ||
}); | ||
it('in preview should return a CountTokensResponse', async () => { | ||
const countTokensResp = | ||
await generativeTextModelPreview.countTokens(TEXT_REQUEST); | ||
expect(countTokensResp.totalTokens).toBeTruthy( | ||
`sys test failure on countTokens in preview, ${countTokensResp}` | ||
); | ||
}); | ||
}); | ||
@@ -416,3 +703,20 @@ | ||
}); | ||
it('in preview should return a stream and aggregated response when passed text', async () => { | ||
const streamingResp = | ||
await generativeTextModelWithPrefixPreview.generateContentStream( | ||
TEXT_REQUEST | ||
); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview using models/gemini-pro, for item ${item}` | ||
); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview using models/gemini-pro for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
it('should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => { | ||
@@ -435,2 +739,19 @@ const streamingResp = | ||
}); | ||
it('in preview should return a stream and aggregated response when passed multipart base64 content when using models/gemini-pro-vision', async () => { | ||
const streamingResp = | ||
await generativeVisionModelWithPrefixPreview.generateContentStream( | ||
MULTI_PART_BASE64_REQUEST | ||
); | ||
for await (const item of streamingResp.stream) { | ||
expect(item.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview using models/gemini-pro-vision, for item ${item}` | ||
); | ||
} | ||
const aggregatedResp = await streamingResp.response; | ||
expect(aggregatedResp.candidates[0]).toBeTruthy( | ||
`sys test failure on generateContentStream in preview using models/gemini-pro-vision for aggregated response: ${aggregatedResp}` | ||
); | ||
}); | ||
}); |
@@ -18,1132 +18,14 @@ /** | ||
/* tslint:disable */ | ||
import {VertexAI} from '../src/index'; | ||
import { | ||
ChatSession, | ||
GenerativeModel, | ||
StartChatParams, | ||
VertexAI, | ||
} from '../src/index'; | ||
import * as StreamFunctions from '../src/process_stream'; | ||
import { | ||
CountTokensRequest, | ||
FinishReason, | ||
FunctionDeclarationSchemaType, | ||
GenerateContentRequest, | ||
GenerateContentResponse, | ||
GenerateContentResult, | ||
HarmBlockThreshold, | ||
HarmCategory, | ||
HarmProbability, | ||
SafetyRating, | ||
SafetySetting, | ||
StreamGenerateContentResult, | ||
Tool, | ||
} from '../src/types/content'; | ||
import {GoogleAuthError} from '../src/types/errors'; | ||
import {constants} from '../src/util'; | ||
const PROJECT = 'test_project'; | ||
const LOCATION = 'test_location'; | ||
const TEST_CHAT_MESSSAGE_TEXT = 'How are you doing today?'; | ||
const TEST_USER_CHAT_MESSAGE = [ | ||
{role: constants.USER_ROLE, parts: [{text: TEST_CHAT_MESSSAGE_TEXT}]}, | ||
]; | ||
const TEST_TOKEN = 'testtoken'; | ||
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [ | ||
{text: TEST_CHAT_MESSSAGE_TEXT}, | ||
{ | ||
file_data: { | ||
file_uri: 'gs://test_bucket/test_image.jpeg', | ||
mime_type: 'image/jpeg', | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [ | ||
{text: TEST_CHAT_MESSSAGE_TEXT}, | ||
{file_data: {file_uri: 'test_image.jpeg', mime_type: 'image/jpeg'}}, | ||
], | ||
}, | ||
]; | ||
const TEST_SAFETY_SETTINGS: SafetySetting[] = [ | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH, | ||
}, | ||
]; | ||
const TEST_SAFETY_RATINGS: SafetyRating[] = [ | ||
{ | ||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, | ||
probability: HarmProbability.NEGLIGIBLE, | ||
}, | ||
]; | ||
const TEST_GENERATION_CONFIG = { | ||
candidate_count: 1, | ||
stop_sequences: ['hello'], | ||
}; | ||
const TEST_CANDIDATES = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: constants.MODEL_ROLE, | ||
parts: [{text: 'I\'m doing great! How are you?'}], | ||
}, | ||
finishReason: FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
citationMetadata: { | ||
citationSources: [ | ||
{ | ||
startIndex: 367, | ||
endIndex: 491, | ||
uri: 'https://www.numerade.com/ask/question/why-does-the-uncertainty-principle-make-it-impossible-to-predict-a-trajectory-for-the-clectron-95172/', | ||
}, | ||
], | ||
}, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE = { | ||
candidates: TEST_CANDIDATES, | ||
usage_metadata: {prompt_token_count: 0, candidates_token_count: 0}, | ||
}; | ||
const TEST_FUNCTION_CALL_RESPONSE = { | ||
functionCall: { | ||
name: 'get_current_weather', | ||
args: { | ||
location: 'LA', | ||
unit: 'fahrenheit', | ||
}, | ||
}, | ||
}; | ||
const TEST_CANDIDATES_WITH_FUNCTION_CALL = [ | ||
{ | ||
index: 1, | ||
content: { | ||
role: constants.MODEL_ROLE, | ||
parts: [TEST_FUNCTION_CALL_RESPONSE], | ||
}, | ||
finishReason: FinishReason.STOP, | ||
finishMessage: '', | ||
safetyRatings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL = { | ||
candidates: TEST_CANDIDATES_WITH_FUNCTION_CALL, | ||
}; | ||
const TEST_FUNCTION_RESPONSE_PART = [ | ||
{ | ||
functionResponse: { | ||
name: 'get_current_weather', | ||
response: {name: 'get_current_weather', content: {weather: 'super nice'}}, | ||
}, | ||
}, | ||
]; | ||
const TEST_CANDIDATES_MISSING_ROLE = [ | ||
{ | ||
index: 1, | ||
content: {parts: [{text: 'I\'m doing great! How are you?'}]}, | ||
finish_reason: 0, | ||
finish_message: '', | ||
safety_ratings: TEST_SAFETY_RATINGS, | ||
}, | ||
]; | ||
const TEST_MODEL_RESPONSE_MISSING_ROLE = { | ||
candidates: TEST_CANDIDATES_MISSING_ROLE, | ||
usage_metadata: {prompt_token_count: 0, candidates_token_count: 0}, | ||
}; | ||
const TEST_EMPTY_MODEL_RESPONSE = { | ||
candidates: [], | ||
}; | ||
const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com'; | ||
const TEST_GCS_FILENAME = 'gs://test_bucket/test_image.jpeg'; | ||
const TEST_MULTIPART_MESSAGE = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [ | ||
{text: 'What is in this picture?'}, | ||
{file_data: {file_uri: TEST_GCS_FILENAME, mime_type: 'image/jpeg'}}, | ||
], | ||
}, | ||
]; | ||
const BASE_64_IMAGE = | ||
'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='; | ||
const INLINE_DATA_FILE_PART = { | ||
inline_data: { | ||
data: BASE_64_IMAGE, | ||
mime_type: 'image/jpeg', | ||
}, | ||
}; | ||
const TEST_MULTIPART_MESSAGE_BASE64 = [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [{text: 'What is in this picture?'}, INLINE_DATA_FILE_PART], | ||
}, | ||
]; | ||
const TEST_TOOLS_WITH_FUNCTION_DECLARATION: Tool[] = [ | ||
{ | ||
function_declarations: [ | ||
{ | ||
name: 'get_current_weather', | ||
description: 'get weather in a given location', | ||
parameters: { | ||
type: FunctionDeclarationSchemaType.OBJECT, | ||
properties: { | ||
location: {type: FunctionDeclarationSchemaType.STRING}, | ||
unit: { | ||
type: FunctionDeclarationSchemaType.STRING, | ||
enum: ['celsius', 'fahrenheit'], | ||
}, | ||
}, | ||
required: ['location'], | ||
}, | ||
}, | ||
], | ||
}, | ||
]; | ||
const fetchResponseObj = { | ||
status: 200, | ||
statusText: 'OK', | ||
ok: true, | ||
headers: {'Content-Type': 'application/json'}, | ||
url: 'url', | ||
}; | ||
/** | ||
* Returns a generator, used to mock the generateContentStream response | ||
* @ignore | ||
*/ | ||
export async function* testGenerator(): AsyncGenerator<GenerateContentResponse> { | ||
yield { | ||
candidates: TEST_CANDIDATES, | ||
}; | ||
} | ||
export async function* testGeneratorWithEmptyResponse(): AsyncGenerator<GenerateContentResponse> { | ||
yield { | ||
candidates: [], | ||
}; | ||
} | ||
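// Note: the suites below never hit a real endpoint; global.fetch and the | ||
// stream-processing helpers in StreamFunctions are stubbed with jasmine spies. | ||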
describe('VertexAI', () => { | ||
let vertexai: VertexAI; | ||
let model: GenerativeModel; | ||
let expectedStreamResult: StreamGenerateContentResult; | ||
let fetchSpy: jasmine.Spy; | ||
beforeEach(() => { | ||
vertexai = new VertexAI({ | ||
describe('SDK', () => { | ||
it('should import VertexAI', () => { | ||
const PROJECT = 'test_project'; | ||
const LOCATION = 'test_location'; | ||
const vertexai = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = new Response( | ||
JSON.stringify(expectedStreamResult), | ||
fetchResponseObj | ||
); | ||
fetchSpy = spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
}); | ||
it('given undefined google auth options, should be instantiated', () => { | ||
expect(vertexai).toBeInstanceOf(VertexAI); | ||
}); | ||
it('given specified google auth options, should be instantiated', () => { | ||
const googleAuthOptions = { | ||
scopes: 'https://www.googleapis.com/auth/cloud-platform', | ||
}; | ||
const vertexai1 = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: googleAuthOptions, | ||
}); | ||
expect(vertexai1).toBeInstanceOf(VertexAI); | ||
}); | ||
it('given inconsistent project ID, should throw error', () => { | ||
const googleAuthOptions = { | ||
projectId: 'another_project', | ||
}; | ||
expect(() => { | ||
new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: googleAuthOptions, | ||
}); | ||
}).toThrow( | ||
new Error( | ||
'inconsistent project ID values. argument project got value test_project but googleAuthOptions.projectId got value another_project' | ||
) | ||
); | ||
}); | ||
it('given scopes missing required scope, should throw GoogleAuthError', () => { | ||
const invalidGoogleAuthOptionsStringScopes = {scopes: 'test.scopes'}; | ||
expect(() => { | ||
new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: invalidGoogleAuthOptionsStringScopes, | ||
}); | ||
}).toThrow( | ||
new GoogleAuthError( | ||
"input GoogleAuthOptions.scopes test.scopes doesn't contain required scope " + | ||
'https://www.googleapis.com/auth/cloud-platform, ' + | ||
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' + | ||
'or leave GoogleAuthOptions.scopes undefined' | ||
) | ||
); | ||
const invalidGoogleAuthOptionsArrayScopes = { | ||
scopes: ['test1.scopes', 'test2.scopes'], | ||
}; | ||
expect(() => { | ||
new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
googleAuthOptions: invalidGoogleAuthOptionsArrayScopes, | ||
}); | ||
}).toThrow( | ||
new GoogleAuthError( | ||
"input GoogleAuthOptions.scopes test1.scopes,test2.scopes doesn't contain required scope " + | ||
'https://www.googleapis.com/auth/cloud-platform, ' + | ||
'please include https://www.googleapis.com/auth/cloud-platform into GoogleAuthOptions.scopes ' + | ||
'or leave GoogleAuthOptions.scopes undefined' | ||
) | ||
); | ||
}); | ||
describe('generateContent', () => { | ||
it('returns a GenerateContentResponse', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await model.generateContent(TEST_CHAT_MESSSAGE_TEXT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a GCS URI', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('raises an error when passed an invalid GCS URI', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_INVALID_GCS_FILE, | ||
}; | ||
await expectAsync(model.generateContent(req)).toBeRejectedWithError( | ||
URIError | ||
); | ||
}); | ||
it('returns a GenerateContentResponse when passed safety_settings and generation_config', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
safety_settings: TEST_SAFETY_SETTINGS, | ||
generation_config: TEST_GENERATION_CONFIG, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('updates the base API endpoint when provided', async () => { | ||
const vertexaiWithBasePath = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
apiEndpoint: TEST_ENDPOINT_BASE_PATH, | ||
}); | ||
model = vertexaiWithBasePath.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
await model.generateContent(req); | ||
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain( | ||
TEST_ENDPOINT_BASE_PATH | ||
); | ||
}); | ||
it('defaults to the base API endpoint when one is not provided', async () => { | ||
const vertexaiWithoutBasePath = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
model = vertexaiWithoutBasePath.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
await model.generateContent(req); | ||
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain( | ||
`${LOCATION}-aiplatform.googleapis.com` | ||
); | ||
}); | ||
it('removes top_k when it is set to 0', async () => { | ||
const reqWithEmptyConfigs: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: {top_k: 0}, | ||
safety_settings: [], | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
await model.generateContent(reqWithEmptyConfigs); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).not.toContain('top_k'); | ||
} | ||
}); | ||
it('includes top_k when it is within 1 - 40', async () => { | ||
const reqWithEmptyConfigs: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE, | ||
generation_config: {top_k: 1}, | ||
safety_settings: [], | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
await model.generateContent(reqWithEmptyConfigs); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
if (typeof requestArgs === 'object' && requestArgs) { | ||
expect(JSON.stringify(requestArgs['body'])).toContain('top_k'); | ||
} | ||
}); | ||
it('aggregates citation metadata', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await model.generateContent(req); | ||
expect( | ||
resp.response.candidates[0].citationMetadata?.citationSources.length | ||
).toEqual( | ||
TEST_MODEL_RESPONSE.candidates[0].citationMetadata.citationSources | ||
.length | ||
); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await model.generateContent(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 1', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 2', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('generateContentStream', () => { | ||
it('returns a GenerateContentResponse when passed text content', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed a string', async () => { | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(TEST_CHAT_MESSSAGE_TEXT); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with a GCS URI', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_MULTIPART_MESSAGE, | ||
}; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content with base64 data', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: TEST_MULTIPART_MESSAGE_BASE64, | ||
}; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedResult); | ||
}); | ||
it('returns a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedStreamResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL), | ||
stream: testGenerator(), | ||
}; | ||
spyOn(StreamFunctions, 'processStream').and.returnValue( | ||
expectedStreamResult | ||
); | ||
const resp = await model.generateContentStream(req); | ||
expect(resp).toEqual(expectedStreamResult); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 1', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when functionResponse does not immediately follow functionCall, case 2', async () => { | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather like in Boston?'}]}, | ||
{ | ||
role: 'function', | ||
parts: TEST_FUNCTION_RESPONSE_PART, | ||
}, | ||
], | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Please ensure that function response turn comes immediately after a function call turn.'; | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('startChat', () => { | ||
it('returns a ChatSession when passed a request arg', () => { | ||
const req: StartChatParams = { | ||
history: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const resp = model.startChat(req); | ||
expect(resp).toBeInstanceOf(ChatSession); | ||
}); | ||
it('returns a ChatSession when passed no request arg', () => { | ||
const resp = model.startChat(); | ||
expect(resp).toBeInstanceOf(ChatSession); | ||
}); | ||
}); | ||
}); | ||
describe('countTokens', () => { | ||
it('returns the token count', async () => { | ||
const vertexai = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
const req: CountTokensRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const responseBody = { | ||
totalTokens: 1, | ||
}; | ||
const response = new Response( | ||
JSON.stringify(responseBody), | ||
fetchResponseObj | ||
); | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
const resp = await model.countTokens(req); | ||
expect(resp).toEqual(responseBody); | ||
}); | ||
}); | ||
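// For reference, the countTokens call mocked above reduces to the following | ||
// in application code (a sketch; the request contents are illustrative): | ||
//   const resp = await model.countTokens({ | ||
//     contents: [{role: 'user', parts: [{text: 'How are you doing today?'}]}], | ||
//   }); | ||
//   console.log(resp.totalTokens); | ||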
describe('ChatSession', () => { | ||
let chatSession: ChatSession; | ||
let chatSessionWithNoArgs: ChatSession; | ||
let chatSessionWithEmptyResponse: ChatSession; | ||
let chatSessionWithFunctionCall: ChatSession; | ||
let vertexai: VertexAI; | ||
let model: GenerativeModel; | ||
let expectedStreamResult: StreamGenerateContentResult; | ||
beforeEach(() => { | ||
vertexai = new VertexAI({project: PROJECT, location: LOCATION}); | ||
model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
chatSession = model.startChat({ | ||
history: TEST_USER_CHAT_MESSAGE, | ||
}); | ||
expect(chatSession.history).toEqual(TEST_USER_CHAT_MESSAGE); | ||
chatSessionWithNoArgs = model.startChat(); | ||
chatSessionWithEmptyResponse = model.startChat(); | ||
chatSessionWithFunctionCall = model.startChat({ | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}); | ||
expectedStreamResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const fetchResult = Promise.resolve( | ||
new Response(JSON.stringify(expectedStreamResult), fetchResponseObj) | ||
); | ||
spyOn(global, 'fetch').and.returnValue(fetchResult); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
}); | ||
describe('sendMessage', () => { | ||
it('returns a GenerateContentResponse and appends to history', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await chatSession.sendMessage(req); | ||
expect(resp).toEqual(expectedResult); | ||
expect(chatSession.history.length).toEqual(3); | ||
}); | ||
it('returns a GenerateContentResponse and appends to history when startChat is passed with no args', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await chatSessionWithNoArgs.sendMessage(req); | ||
expect(resp).toEqual(expectedResult); | ||
expect(chatSessionWithNoArgs.history.length).toEqual(2); | ||
}); | ||
it('throws an error when the model returns an empty response', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_EMPTY_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
await expectAsync( | ||
chatSessionWithEmptyResponse.sendMessage(req) | ||
).toBeRejected(); | ||
expect(chatSessionWithEmptyResponse.history.length).toEqual(0); | ||
}); | ||
it('returns a GenerateContentResponse when passed multi-part content', async () => { | ||
const req = TEST_MULTIPART_MESSAGE[0]['parts']; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
spyOn(StreamFunctions, 'processNonStream').and.returnValue( | ||
expectedResult | ||
); | ||
const resp = await chatSessionWithNoArgs.sendMessage(req); | ||
expect(resp).toEqual(expectedResult); | ||
expect(chatSessionWithNoArgs.history.length).toEqual(2); | ||
}); | ||
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => { | ||
const functionCallChatMessage = 'What is the weather in LA?'; | ||
const expectedFunctionCallResponse: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
const expectedResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL, | ||
}; | ||
const streamSpy = spyOn(StreamFunctions, 'processNonStream'); | ||
streamSpy.and.returnValue(expectedResult); | ||
const response1 = await chatSessionWithFunctionCall.sendMessage( | ||
functionCallChatMessage | ||
); | ||
expect(response1).toEqual(expectedFunctionCallResponse); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(2); | ||
// Send a follow-up message with a FunctionResponse | ||
const expectedFollowUpResponse: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
const expectedFollowUpResult: GenerateContentResult = { | ||
response: TEST_MODEL_RESPONSE, | ||
}; | ||
streamSpy.and.returnValue(expectedFollowUpResult); | ||
const response2 = await chatSessionWithFunctionCall.sendMessage( | ||
TEST_FUNCTION_RESPONSE_PART | ||
); | ||
expect(response2).toEqual(expectedFollowUpResponse); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(4); | ||
}); | ||
it('throws ClientError when request has no content', async () => { | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: No content is provided for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessage([]).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when request mixes functionResponse part with other part types', async () => { | ||
const chatRequest = [ | ||
'what is the weather like in LA', | ||
TEST_FUNCTION_RESPONSE_PART[0], | ||
]; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessage(chatRequest).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('sendMessageStream', () => { | ||
it('returns a StreamGenerateContentResponse and appends to history', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
const chatSession = model.startChat({ | ||
history: [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [{text: 'How are you doing today?'}], | ||
}, | ||
], | ||
}); | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
expect(chatSession.history.length).toEqual(1); | ||
expect(chatSession.history[0].role).toEqual(constants.USER_ROLE); | ||
const result = await chatSession.sendMessageStream(req); | ||
const response = await result.response; | ||
const expectedResponse = await expectedResult.response; | ||
expect(response).toEqual(expectedResponse); | ||
expect(chatSession.history.length).toEqual(3); | ||
expect(chatSession.history[0].role).toEqual(constants.USER_ROLE); | ||
expect(chatSession.history[1].role).toEqual(constants.USER_ROLE); | ||
expect(chatSession.history[2].role).toEqual(constants.MODEL_ROLE); | ||
}); | ||
it('returns a StreamGenerateContentResponse and appends role if missing', async () => { | ||
const req = 'How are you doing today?'; | ||
const expectedResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_MISSING_ROLE), | ||
stream: testGenerator(), | ||
}; | ||
const chatSession = model.startChat({ | ||
history: [ | ||
{ | ||
role: constants.USER_ROLE, | ||
parts: [{text: 'How are you doing today?'}], | ||
}, | ||
], | ||
}); | ||
spyOn(StreamFunctions, 'processStream').and.returnValue(expectedResult); | ||
expect(chatSession.history.length).toEqual(1); | ||
expect(chatSession.history[0].role).toEqual(constants.USER_ROLE); | ||
const result = await chatSession.sendMessageStream(req); | ||
const response = await result.response; | ||
const expectedResponse = await expectedResult.response; | ||
expect(response).toEqual(expectedResponse); | ||
expect(chatSession.history.length).toEqual(3); | ||
expect(chatSession.history[0].role).toEqual(constants.USER_ROLE); | ||
expect(chatSession.history[1].role).toEqual(constants.USER_ROLE); | ||
expect(chatSession.history[2].role).toEqual(constants.MODEL_ROLE); | ||
}); | ||
it('returns a FunctionCall and appends to history when passed a FunctionDeclaration', async () => { | ||
const functionCallChatMessage = 'What is the weather in LA?'; | ||
const expectedStreamResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL), | ||
stream: testGenerator(), | ||
}; | ||
const streamSpy = spyOn(StreamFunctions, 'processStream'); | ||
streamSpy.and.returnValue(expectedStreamResult); | ||
const response1 = await chatSessionWithFunctionCall.sendMessageStream( | ||
functionCallChatMessage | ||
); | ||
expect(response1).toEqual(expectedStreamResult); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(2); | ||
// Send a follow-up message with a FunctionResponse | ||
const expectedFollowUpStreamResult: StreamGenerateContentResult = { | ||
response: Promise.resolve(TEST_MODEL_RESPONSE), | ||
stream: testGenerator(), | ||
}; | ||
streamSpy.and.returnValue(expectedFollowUpStreamResult); | ||
const response2 = await chatSessionWithFunctionCall.sendMessageStream( | ||
TEST_FUNCTION_RESPONSE_PART | ||
); | ||
expect(response2).toEqual(expectedFollowUpStreamResult); | ||
expect(chatSessionWithFunctionCall.history.length).toEqual(4); | ||
}); | ||
it('throws ClientError when request has no content', async () => { | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: No content is provided for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessageStream([]).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('throws ClientError when request mixes functionResponse part with other part types', async () => { | ||
const chatRequest = [ | ||
'what is the weather like in LA', | ||
TEST_FUNCTION_RESPONSE_PART[0], | ||
]; | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: Within a single message, FunctionResponse cannot be mixed with other type of part in the request for sending chat message.'; | ||
await chatSessionWithNoArgs.sendMessageStream(chatRequest).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
}); | ||
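// In application code, the chat flow mocked above reduces to the following | ||
// (a sketch; the message text is illustrative): | ||
//   const chat = model.startChat(); | ||
//   const result = await chat.sendMessage('How are you doing today?'); | ||
//   console.log(result.response.candidates[0].content.parts[0].text); | ||
//   // history now holds the user turn followed by the model turn | ||
//   console.log(chat.history.length); | ||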
describe('when fetch throws an exception', () => { | ||
const expectedErrorMessage = | ||
'[VertexAI.GoogleGenerativeAIError]: exception posting request'; | ||
const vertexai = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
const chatSession = model.startChat(); | ||
const message = 'hi'; | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const countTokenReq: CountTokensRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
spyOn(global, 'fetch').and.throwError('error'); | ||
}); | ||
it('generateContent should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
}); | ||
it('generateContentStream should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
}); | ||
it('sendMessage should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
}); | ||
it('countTokens should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
}); | ||
}); | ||
describe('when response is undefined', () => { | ||
const expectedErrorMessage = | ||
'[VertexAI.GoogleGenerativeAIError]: response is undefined'; | ||
const vertexai = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const message = 'hi'; | ||
const chatSession = model.startChat(); | ||
const countTokenReq: CountTokensRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOn(global, 'fetch').and.resolveTo(); | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
}); | ||
it('generateContent should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('generateContentStream should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('sendMessage should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
await chatSession.sendMessage(message).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('countTokens should throw GoogleGenerativeAI error', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
await model.countTokens(countTokenReq).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('when response is 4XX', () => { | ||
const expectedErrorMessage = | ||
'[VertexAI.ClientError]: got status: 400 Bad Request'; | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const vertexai = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const fetch400Obj = { | ||
status: 400, | ||
statusText: 'Bad Request', | ||
ok: false, | ||
}; | ||
const body = {}; | ||
const response = new Response(JSON.stringify(body), fetch400Obj); | ||
const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
const message = 'hi'; | ||
const chatSession = model.startChat(); | ||
const countTokenReq: CountTokensRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
}); | ||
it('generateContent should throw ClientError', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('generateContentStream should throw ClientError', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('sendMessage should throw ClientError', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
await chatSession.sendMessage(message).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('countTokens should throw ClientError', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
await model.countTokens(countTokenReq).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); | ||
describe('when response is not OK and not 4XX', () => { | ||
const expectedErrorMessage = | ||
'[VertexAI.GoogleGenerativeAIError]: got status: 500 Internal Server Error'; | ||
const req: GenerateContentRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
const vertexai = new VertexAI({ | ||
project: PROJECT, | ||
location: LOCATION, | ||
}); | ||
const fetch500Obj = { | ||
status: 500, | ||
statusText: 'Internal Server Error', | ||
ok: false, | ||
}; | ||
const body = {}; | ||
const response = new Response(JSON.stringify(body), fetch500Obj); | ||
const model = vertexai.preview.getGenerativeModel({model: 'gemini-pro'}); | ||
const message = 'hi'; | ||
const chatSession = model.startChat(); | ||
const countTokenReq: CountTokensRequest = { | ||
contents: TEST_USER_CHAT_MESSAGE, | ||
}; | ||
beforeEach(() => { | ||
spyOnProperty(model, 'token', 'get').and.resolveTo(TEST_TOKEN); | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
}); | ||
it('generateContent should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.generateContent(req)).toBeRejected(); | ||
await model.generateContent(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('generateContentStream should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.generateContentStream(req)).toBeRejected(); | ||
await model.generateContentStream(req).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('sendMessage should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(chatSession.sendMessage(message)).toBeRejected(); | ||
await chatSession.sendMessage(message).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
it('countTokens should throw GoogleGenerativeAIError', async () => { | ||
await expectAsync(model.countTokens(countTokenReq)).toBeRejected(); | ||
await model.countTokens(countTokenReq).catch(e => { | ||
expect(e.message).toEqual(expectedErrorMessage); | ||
}); | ||
}); | ||
}); |
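The four error-path suites above pin down the message prefixes a caller can branch on: ClientError for 4XX responses, and GoogleGenerativeAIError for fetch exceptions, undefined responses, and non-4XX failures. A handler consistent with those assertions might look like the sketch below (the prefixes are copied from the expectations above; the retry guidance in the comments is an assumption, not documented library behavior).

import {VertexAI} from '@google-cloud/vertexai';

async function generateWithDiagnostics() {
  // Placeholder project/location; substitute real values.
  const vertexAI = new VertexAI({project: 'my-project', location: 'us-central1'});
  const model = vertexAI.preview.getGenerativeModel({model: 'gemini-pro'});
  try {
    return await model.generateContent('How are you doing today?');
  } catch (e) {
    const message = (e as Error).message;
    if (message.startsWith('[VertexAI.ClientError]')) {
      // 4XX: the request itself was rejected; fix it before retrying.
    } else if (message.startsWith('[VertexAI.GoogleGenerativeAIError]')) {
      // Transport failure, empty response, or 5XX: a backoff retry may succeed.
    }
    throw e;
  }
}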
Sorry, the diff of this file is not supported yet
License Policy Violation
License: This package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
Major refactor
Supply chain risk: Package has recently undergone a major refactor. It may be unstable or indicate significant internal changes. Use caution when updating to versions that include significant changes.
Found 1 instance in 1 package