@google-cloud/vertexai
Advanced tools
Comparing version 1.1.0 to 1.2.0
{ | ||
".": "1.1.0" | ||
".": "1.2.0" | ||
} |
@@ -23,2 +23,2 @@ /** | ||
*/ | ||
export declare function countTokens(location: string, project: string, publisherModelEndpoint: string, token: Promise<string | null | undefined>, request: CountTokensRequest, apiEndpoint?: string, requestOptions?: RequestOptions): Promise<CountTokensResponse>; | ||
export declare function countTokens(location: string, resourcePath: string, token: Promise<string | null | undefined>, request: CountTokensRequest, apiEndpoint?: string, requestOptions?: RequestOptions): Promise<CountTokensResponse>; |
@@ -21,2 +21,3 @@ "use strict"; | ||
const errors_1 = require("../types/errors"); | ||
const constants = require("../util/constants"); | ||
const post_fetch_processing_1 = require("./post_fetch_processing"); | ||
@@ -29,8 +30,7 @@ const post_request_1 = require("./post_request"); | ||
*/ | ||
async function countTokens(location, project, publisherModelEndpoint, token, request, apiEndpoint, requestOptions) { | ||
async function countTokens(location, resourcePath, token, request, apiEndpoint, requestOptions) { | ||
const response = await (0, post_request_1.postRequest)({ | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: 'countTokens', | ||
resourcePath: resourcePath, | ||
resourceMethod: constants.COUNT_TOKENS_METHOD, | ||
token: await token, | ||
@@ -37,0 +37,0 @@ data: request, |
@@ -23,3 +23,3 @@ /** | ||
import { GenerateContentRequest, GenerateContentResult, GenerationConfig, RequestOptions, SafetySetting, StreamGenerateContentResult, Tool } from '../types/content'; | ||
export declare function generateContent(location: string, project: string, publisherModelEndpoint: string, token: Promise<string | null | undefined>, request: GenerateContentRequest | string, apiEndpoint?: string, generationConfig?: GenerationConfig, safetySettings?: SafetySetting[], tools?: Tool[], requestOptions?: RequestOptions): Promise<GenerateContentResult>; | ||
export declare function generateContent(location: string, resourcePath: string, token: Promise<string | null | undefined>, request: GenerateContentRequest | string, apiEndpoint?: string, generationConfig?: GenerationConfig, safetySettings?: SafetySetting[], tools?: Tool[], requestOptions?: RequestOptions): Promise<GenerateContentResult>; | ||
/** | ||
@@ -32,2 +32,2 @@ * Make an async stream request to generate content. The response will be | ||
*/ | ||
export declare function generateContentStream(location: string, project: string, publisherModelEndpoint: string, token: Promise<string | null | undefined>, request: GenerateContentRequest | string, apiEndpoint?: string, generationConfig?: GenerationConfig, safetySettings?: SafetySetting[], tools?: Tool[], requestOptions?: RequestOptions): Promise<StreamGenerateContentResult>; | ||
export declare function generateContentStream(location: string, resourcePath: string, token: Promise<string | null | undefined>, request: GenerateContentRequest | string, apiEndpoint?: string, generationConfig?: GenerationConfig, safetySettings?: SafetySetting[], tools?: Tool[], requestOptions?: RequestOptions): Promise<StreamGenerateContentResult>; |
@@ -25,3 +25,3 @@ "use strict"; | ||
const pre_fetch_processing_1 = require("./pre_fetch_processing"); | ||
async function generateContent(location, project, publisherModelEndpoint, token, request, apiEndpoint, generationConfig, safetySettings, tools, requestOptions) { | ||
async function generateContent(location, resourcePath, token, request, apiEndpoint, generationConfig, safetySettings, tools, requestOptions) { | ||
var _a, _b, _c; | ||
@@ -42,9 +42,9 @@ request = (0, pre_fetch_processing_1.formatContentRequest)(request, generationConfig, safetySettings); | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourcePath, | ||
resourceMethod: constants.GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
requestOptions: requestOptions, | ||
apiEndpoint, | ||
requestOptions, | ||
apiVersion: (0, pre_fetch_processing_1.getApiVersion)(request), | ||
}).catch(e => { | ||
@@ -66,3 +66,3 @@ throw new errors_1.GoogleGenerativeAIError('exception posting request to model', e); | ||
*/ | ||
async function generateContentStream(location, project, publisherModelEndpoint, token, request, apiEndpoint, generationConfig, safetySettings, tools, requestOptions) { | ||
async function generateContentStream(location, resourcePath, token, request, apiEndpoint, generationConfig, safetySettings, tools, requestOptions) { | ||
var _a, _b, _c; | ||
@@ -83,9 +83,9 @@ request = (0, pre_fetch_processing_1.formatContentRequest)(request, generationConfig, safetySettings); | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourcePath, | ||
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
requestOptions: requestOptions, | ||
apiEndpoint, | ||
requestOptions, | ||
apiVersion: (0, pre_fetch_processing_1.getApiVersion)(request), | ||
}).catch(e => { | ||
@@ -92,0 +92,0 @@ throw new errors_1.GoogleGenerativeAIError('exception posting request', e); |
@@ -234,2 +234,3 @@ "use strict"; | ||
groundingAttributions: [], | ||
retrievalQueries: [], | ||
}; | ||
@@ -246,2 +247,10 @@ const groundingMetadataAggregated = (_a = aggregatedCandidate.groundingMetadata) !== null && _a !== void 0 ? _a : emptyGroundingMetadata; | ||
} | ||
if (groundingMetadataChunk.retrievalQueries) { | ||
groundingMetadataAggregated.retrievalQueries = | ||
groundingMetadataAggregated.retrievalQueries.concat(groundingMetadataChunk.retrievalQueries); | ||
} | ||
if (groundingMetadataChunk.searchEntryPoint) { | ||
groundingMetadataAggregated.searchEntryPoint = | ||
groundingMetadataChunk.searchEntryPoint; | ||
} | ||
return groundingMetadataAggregated; | ||
@@ -248,0 +257,0 @@ } |
@@ -22,5 +22,4 @@ /** | ||
*/ | ||
export declare function postRequest({ region, project, resourcePath, resourceMethod, token, data, apiEndpoint, requestOptions, apiVersion, }: { | ||
export declare function postRequest({ region, resourcePath, resourceMethod, token, data, apiEndpoint, requestOptions, apiVersion, }: { | ||
region: string; | ||
project: string; | ||
resourcePath: string; | ||
@@ -27,0 +26,0 @@ resourceMethod: string; |
@@ -33,5 +33,5 @@ "use strict"; | ||
*/ | ||
async function postRequest({ region, project, resourcePath, resourceMethod, token, data, apiEndpoint, requestOptions, apiVersion = 'v1', }) { | ||
async function postRequest({ region, resourcePath, resourceMethod, token, data, apiEndpoint, requestOptions, apiVersion = 'v1', }) { | ||
const vertexBaseEndpoint = apiEndpoint !== null && apiEndpoint !== void 0 ? apiEndpoint : `${region}-${API_BASE_PATH}`; | ||
let vertexEndpoint = `https://${vertexBaseEndpoint}/${apiVersion}/projects/${project}/locations/${region}/${resourcePath}:${resourceMethod}`; | ||
let vertexEndpoint = `https://${vertexBaseEndpoint}/${apiVersion}/${resourcePath}:${resourceMethod}`; | ||
// Use server sent events for streamGenerateContent | ||
@@ -38,0 +38,0 @@ if (resourceMethod === constants.STREAMING_GENERATE_CONTENT_METHOD) { |
@@ -21,1 +21,4 @@ /** | ||
export declare function validateGenerationConfig(generationConfig: GenerationConfig): GenerationConfig; | ||
export declare function getApiVersion(request: GenerateContentRequest): 'v1' | 'v1beta1'; | ||
export declare function hasVertexRagStore(request: GenerateContentRequest): boolean; | ||
export declare function hasVertexAISearch(request: GenerateContentRequest): boolean; |
@@ -19,3 +19,4 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.validateGenerationConfig = exports.validateGenerateContentRequest = exports.formatContentRequest = void 0; | ||
exports.hasVertexAISearch = exports.hasVertexRagStore = exports.getApiVersion = exports.validateGenerationConfig = exports.validateGenerateContentRequest = exports.formatContentRequest = void 0; | ||
const errors_1 = require("../types/errors"); | ||
const constants = require("../util/constants"); | ||
@@ -48,2 +49,5 @@ function formatContentRequest(request, generationConfig, safetySettings) { | ||
} | ||
if (hasVertexAISearch(request) && hasVertexRagStore(request)) { | ||
throw new errors_1.ClientError('Found both vertexAiSearch and vertexRagStore field are set in tool. Either set vertexAiSearch or vertexRagStore.'); | ||
} | ||
} | ||
@@ -60,2 +64,32 @@ exports.validateGenerateContentRequest = validateGenerateContentRequest; | ||
exports.validateGenerationConfig = validateGenerationConfig; | ||
function getApiVersion(request) { | ||
return hasVertexRagStore(request) ? 'v1beta1' : 'v1'; | ||
} | ||
exports.getApiVersion = getApiVersion; | ||
function hasVertexRagStore(request) { | ||
var _a; | ||
for (const tool of (_a = request === null || request === void 0 ? void 0 : request.tools) !== null && _a !== void 0 ? _a : []) { | ||
const retrieval = tool.retrieval; | ||
if (!retrieval) | ||
continue; | ||
if (retrieval.vertexRagStore) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
exports.hasVertexRagStore = hasVertexRagStore; | ||
function hasVertexAISearch(request) { | ||
var _a; | ||
for (const tool of (_a = request === null || request === void 0 ? void 0 : request.tools) !== null && _a !== void 0 ? _a : []) { | ||
const retrieval = tool.retrieval; | ||
if (!retrieval) | ||
continue; | ||
if (retrieval.vertexAiSearch) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
exports.hasVertexAISearch = hasVertexAISearch; | ||
//# sourceMappingURL=pre_fetch_processing.js.map |
@@ -25,5 +25,4 @@ "use strict"; | ||
const StreamFunctions = require("../post_fetch_processing"); | ||
const TEST_PROJECT = 'test-project'; | ||
const TEST_LOCATION = 'test-location'; | ||
const TEST_PUBLISHER_MODEL_ENDPOINT = 'test-publisher-model-endpoint'; | ||
const TEST_RESOURCE_PATH = 'test-resource-path'; | ||
const TEST_TOKEN = 'testtoken'; | ||
@@ -36,2 +35,8 @@ const TEST_TOKEN_PROMISE = Promise.resolve(TEST_TOKEN); | ||
]; | ||
const CONTENTS = [ | ||
{ | ||
role: 'user', | ||
parts: [{ text: 'What is the weater like in Boston?' }], | ||
}, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [ | ||
@@ -186,2 +191,7 @@ { | ||
]; | ||
const TEST_TOOLS_WITH_RAG = [ | ||
{ | ||
retrieval: { vertexRagStore: { ragResources: [{ ragCorpus: 'ragCorpus' }] } }, | ||
}, | ||
]; | ||
const fetchResponseObj = { | ||
@@ -218,3 +228,3 @@ status: 200, | ||
fetchSpy = spyOn(global, 'fetch').and.resolveTo(response); | ||
const resp = await (0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResponseBody); | ||
@@ -228,3 +238,3 @@ }); | ||
}); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT, TEST_REQUEST_OPTIONS)).toBeRejected(); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT, TEST_REQUEST_OPTIONS)).toBeRejected(); | ||
expect(fetchSpy.calls.allArgs()[0][1].signal).toBeInstanceOf(AbortSignal); | ||
@@ -245,3 +255,3 @@ }); | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejected(); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejected(); | ||
}); | ||
@@ -261,3 +271,3 @@ it('throw ClientError when not OK and 4XX', async () => { | ||
spyOn(global, 'fetch').and.resolveTo(response); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejected(); | ||
await expectAsync((0, count_tokens_1.countTokens)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejected(); | ||
}); | ||
@@ -279,3 +289,3 @@ }); | ||
}); | ||
await expectAsync((0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT, TEST_GENERATION_CONFIG, TEST_SAFETY_SETTINGS, TEST_EMPTY_TOOLS, TEST_REQUEST_OPTIONS)).toBeRejected(); | ||
await expectAsync((0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT, TEST_GENERATION_CONFIG, TEST_SAFETY_SETTINGS, TEST_EMPTY_TOOLS, TEST_REQUEST_OPTIONS)).toBeRejected(); | ||
expect(fetchSpy.calls.allArgs()[0][1].signal).toBeInstanceOf(AbortSignal); | ||
@@ -291,3 +301,3 @@ }); | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -300,3 +310,3 @@ }); | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_CHAT_MESSAGE_TEXT, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, TEST_CHAT_MESSAGE_TEXT, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -312,3 +322,3 @@ }); | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -320,3 +330,3 @@ }); | ||
}; | ||
await expectAsync((0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejectedWithError(URIError); | ||
await expectAsync((0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT)).toBeRejectedWithError(URIError); | ||
}); | ||
@@ -333,3 +343,3 @@ it('returns a GenerateContentResponse when passed safetySettings and generationConfig', async () => { | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -342,3 +352,3 @@ }); | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_ENDPOINT_BASE_PATH); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_ENDPOINT_BASE_PATH); | ||
expect(fetchSpy.calls.allArgs()[0][0].toString()).toContain(TEST_ENDPOINT_BASE_PATH); | ||
@@ -353,3 +363,3 @@ }); | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, reqWithEmptyConfigs, TEST_API_ENDPOINT); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, reqWithEmptyConfigs, TEST_API_ENDPOINT); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
@@ -367,3 +377,3 @@ if (typeof requestArgs === 'object' && requestArgs) { | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, reqWithEmptyConfigs, TEST_API_ENDPOINT); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, reqWithEmptyConfigs, TEST_API_ENDPOINT); | ||
const requestArgs = fetchSpy.calls.allArgs()[0][1]; | ||
@@ -380,3 +390,3 @@ if (typeof requestArgs === 'object' && requestArgs) { | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect((_a = resp.response.candidates[0].citationMetadata) === null || _a === void 0 ? void 0 : _a.citations.length).toEqual(TEST_MODEL_RESPONSE.candidates[0].citationMetadata.citations.length); | ||
@@ -386,5 +396,3 @@ }); | ||
const req = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weater like in Boston?' }] }, | ||
], | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
@@ -396,3 +404,3 @@ }; | ||
fetchSpy.and.resolveTo(new Response(JSON.stringify(TEST_MODEL_RESPONSE_WITH_FUNCTION_CALL), fetchResponseObj)); | ||
const actualResult = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const actualResult = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(actualResult).toEqual(expectedResult); | ||
@@ -407,8 +415,3 @@ expect(types_1.GenerateContentResponseHandler.getFunctionCallsFromCandidate(actualResult.response.candidates[0])).toHaveSize(1); | ||
const req = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{ text: 'What is the weater like in Boston?' }], | ||
}, | ||
], | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
@@ -420,3 +423,3 @@ }; | ||
fetchSpy.and.resolveTo(new Response(JSON.stringify(TEST_MODEL_RESPONSE_WITH_INVALID_DATA), fetchResponseObj)); | ||
const actualResult = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const actualResult = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(actualResult).toEqual(expectedResult); | ||
@@ -427,14 +430,28 @@ expect(types_1.GenerateContentResponseHandler.getFunctionCallsFromCandidate((_a = actualResult.response.candidates) === null || _a === void 0 ? void 0 : _a[0])).toHaveSize(0); | ||
const req = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{ text: 'What is the weater like in Boston?' }], | ||
}, | ||
], | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
}; | ||
fetchSpy.and.resolveTo(new Response(JSON.stringify({}), fetchResponseObj)); | ||
const actualResult = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const actualResult = await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(actualResult.response.candidates).not.toBeDefined(); | ||
}); | ||
it('should use v1 apiVersion', async () => { | ||
const request = { | ||
contents: CONTENTS, | ||
}; | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, request, TEST_API_ENDPOINT); | ||
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0]; | ||
expect(vertexEndpoint).toContain('/v1/'); | ||
}); | ||
it('should use v1beta1 apiVersion when set RAG in tools', async () => { | ||
const request = { | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_RAG, | ||
}; | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
await (0, generate_content_1.generateContent)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, request, TEST_API_ENDPOINT); | ||
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0]; | ||
expect(vertexEndpoint).toContain('/v1beta1/'); | ||
}); | ||
}); | ||
@@ -461,3 +478,3 @@ describe('generateContentStream', () => { | ||
}); | ||
await expectAsync((0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT, TEST_GENERATION_CONFIG, TEST_SAFETY_SETTINGS, TEST_EMPTY_TOOLS, TEST_REQUEST_OPTIONS)).toBeRejected(); | ||
await expectAsync((0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT, TEST_GENERATION_CONFIG, TEST_SAFETY_SETTINGS, TEST_EMPTY_TOOLS, TEST_REQUEST_OPTIONS)).toBeRejected(); | ||
expect(fetchSpy.calls.allArgs()[0][1].signal).toBeInstanceOf(AbortSignal); | ||
@@ -475,3 +492,3 @@ }); | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -486,3 +503,3 @@ }); | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, TEST_API_ENDPOINT, TEST_CHAT_MESSAGE_TEXT); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, TEST_API_ENDPOINT, TEST_CHAT_MESSAGE_TEXT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -500,3 +517,3 @@ }); | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -514,3 +531,3 @@ }); | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedResult); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const resp = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(resp).toEqual(expectedResult); | ||
@@ -531,3 +548,3 @@ }); | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedStreamResult); | ||
const result = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const result = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(result).toEqual(expectedStreamResult); | ||
@@ -552,3 +569,3 @@ const response = await result.response; | ||
spyOn(StreamFunctions, 'processStream').and.resolveTo(expectedStreamResult); | ||
const actualResult = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_PROJECT, TEST_PUBLISHER_MODEL_ENDPOINT, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
const actualResult = await (0, generate_content_1.generateContentStream)(TEST_LOCATION, TEST_RESOURCE_PATH, TEST_TOKEN_PROMISE, req, TEST_API_ENDPOINT); | ||
expect(actualResult).toEqual(expectedStreamResult); | ||
@@ -555,0 +572,0 @@ const response = await actualResult.response; |
@@ -32,4 +32,3 @@ "use strict"; | ||
const LOCATION = 'location'; | ||
const PROJECT = 'project'; | ||
const PUBLISHER_MODEL_ENDPOINT = 'publisher_model_endpoint'; | ||
const RESOURCE_PATH = 'RESOURCE_PATH'; | ||
const TOKEN = Promise.resolve('token'); | ||
@@ -61,3 +60,3 @@ const GENERATE_CONTENT_REQUEST = 'generate_content_request'; | ||
spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
const actualResult = await (0, generate_content_1.generateContent)(LOCATION, PROJECT, PUBLISHER_MODEL_ENDPOINT, TOKEN, GENERATE_CONTENT_REQUEST); | ||
const actualResult = await (0, generate_content_1.generateContent)(LOCATION, RESOURCE_PATH, TOKEN, GENERATE_CONTENT_REQUEST); | ||
const actualResponse = actualResult.response; | ||
@@ -70,3 +69,3 @@ expect(actualResponse).toEqual(test_data_1.UNARY_RESPONSE_1); | ||
spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
const actualResult = await (0, generate_content_1.generateContent)(LOCATION, PROJECT, PUBLISHER_MODEL_ENDPOINT, TOKEN, GENERATE_CONTENT_REQUEST); | ||
const actualResult = await (0, generate_content_1.generateContent)(LOCATION, RESOURCE_PATH, TOKEN, GENERATE_CONTENT_REQUEST); | ||
const actualResponse = actualResult.response; | ||
@@ -77,3 +76,3 @@ expect(actualResponse).toEqual(expectedResult); | ||
spyOn(global, 'fetch').and.resolveTo(new Response(JSON.stringify({}), fetchResponseObj)); | ||
const actualResult = await (0, generate_content_1.generateContent)(LOCATION, PROJECT, PUBLISHER_MODEL_ENDPOINT, TOKEN, GENERATE_CONTENT_REQUEST); | ||
const actualResult = await (0, generate_content_1.generateContent)(LOCATION, RESOURCE_PATH, TOKEN, GENERATE_CONTENT_REQUEST); | ||
const actualResponse = actualResult.response; | ||
@@ -88,3 +87,3 @@ expect(actualResponse).toEqual({}); | ||
const processStreamSpy = spyOn(PostFetchFunctions, 'processStream'); | ||
await (0, generate_content_1.generateContentStream)(LOCATION, PROJECT, PUBLISHER_MODEL_ENDPOINT, TOKEN, GENERATE_CONTENT_REQUEST); | ||
await (0, generate_content_1.generateContentStream)(LOCATION, RESOURCE_PATH, TOKEN, GENERATE_CONTENT_REQUEST); | ||
const actualArg = processStreamSpy.calls.allArgs()[0][0]; | ||
@@ -100,3 +99,3 @@ expect(actualArg).toBeDefined(); | ||
spyOn(global, 'fetch').and.resolveTo(fetchResult); | ||
const actualResponse = await (0, count_tokens_1.countTokens)(LOCATION, PROJECT, PUBLISHER_MODEL_ENDPOINT, TOKEN, COUNT_TOKEN_REQUEST); | ||
const actualResponse = await (0, count_tokens_1.countTokens)(LOCATION, RESOURCE_PATH, TOKEN, COUNT_TOKEN_REQUEST); | ||
expect(actualResponse).toEqual(test_data_1.COUNT_TOKENS_RESPONSE_1); | ||
@@ -103,0 +102,0 @@ }); |
@@ -6,3 +6,2 @@ "use strict"; | ||
const REGION = 'us-central1'; | ||
const PROJECT = 'project-id'; | ||
const RESOURCE_PATH = 'resource-path'; | ||
@@ -26,3 +25,2 @@ const RESOURCE_METHOD = 'resource-method'; | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -46,3 +44,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -64,3 +61,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -84,3 +80,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -104,3 +99,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -122,3 +116,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -145,3 +138,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -171,3 +163,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -174,0 +165,0 @@ resourceMethod: RESOURCE_METHOD, |
@@ -308,2 +308,5 @@ "use strict"; | ||
], | ||
searchEntryPoint: { | ||
renderedContent: 'rendered content for later chunk for first candidate', | ||
}, | ||
}, | ||
@@ -330,2 +333,5 @@ }, | ||
], | ||
searchEntryPoint: { | ||
renderedContent: 'rendered content for later chunk for second candidate', | ||
}, | ||
}, | ||
@@ -415,2 +421,6 @@ }, | ||
], | ||
retrievalQueries: [], | ||
searchEntryPoint: { | ||
renderedContent: 'rendered content for later chunk for first candidate', | ||
}, | ||
}, | ||
@@ -486,2 +496,6 @@ finishReason: 'STOP', | ||
], | ||
retrievalQueries: [], | ||
searchEntryPoint: { | ||
renderedContent: 'rendered content for later chunk for second candidate', | ||
}, | ||
}, | ||
@@ -488,0 +502,0 @@ finishReason: 'STOP', |
@@ -31,3 +31,3 @@ /** | ||
private sendStreamPromise; | ||
private readonly publisherModelEndpoint; | ||
private readonly resourcePath; | ||
private readonly googleAuth; | ||
@@ -111,3 +111,3 @@ protected readonly requestOptions?: RequestOptions; | ||
private sendStreamPromise; | ||
private readonly publisherModelEndpoint; | ||
private readonly resourcePath; | ||
private readonly googleAuth; | ||
@@ -114,0 +114,0 @@ protected readonly requestOptions?: RequestOptions; |
@@ -20,5 +20,6 @@ "use strict"; | ||
exports.ChatSessionPreview = exports.ChatSession = void 0; | ||
const util_1 = require("./util"); | ||
const generate_content_1 = require("../functions/generate_content"); | ||
const errors_1 = require("../types/errors"); | ||
const util_1 = require("../util"); | ||
const util_2 = require("../util"); | ||
/** | ||
@@ -46,3 +47,3 @@ * The `ChatSession` class is used to make multiturn send message requests. You | ||
this.googleAuth = request.googleAuth; | ||
this.publisherModelEndpoint = request.publisherModelEndpoint; | ||
this.resourcePath = request.resourcePath; | ||
this.historyInternal = (_a = request.history) !== null && _a !== void 0 ? _a : []; | ||
@@ -55,5 +56,4 @@ this.generationConfig = request.generationConfig; | ||
if (request.systemInstruction) { | ||
request.systemInstruction.role = util_1.constants.SYSTEM_ROLE; | ||
this.systemInstruction = (0, util_1.formulateSystemInstructionIntoContent)(request.systemInstruction); | ||
} | ||
this.systemInstruction = request.systemInstruction; | ||
} | ||
@@ -67,3 +67,3 @@ /** | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new errors_1.GoogleAuthError(util_1.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
throw new errors_1.GoogleAuthError(util_2.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
}); | ||
@@ -100,3 +100,3 @@ return tokenPromise; | ||
}; | ||
const generateContentResult = await (0, generate_content_1.generateContent)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
const generateContentResult = await (0, generate_content_1.generateContent)(this.location, this.resourcePath, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
throw e; | ||
@@ -157,7 +157,9 @@ }); | ||
}; | ||
const streamGenerateContentResultPromise = (0, generate_content_1.generateContentStream)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
const streamGenerateContentResultPromise = (0, generate_content_1.generateContentStream)(this.location, this.resourcePath, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
throw e; | ||
}); | ||
this.sendStreamPromise = this.appendHistory(streamGenerateContentResultPromise, newContent).catch(e => { | ||
throw new errors_1.GoogleGenerativeAIError('exception appending chat history', e); | ||
// Errors from remote endpoint will be catchable by user from streamGenerateContentResultPromise | ||
// Errors in appendHistory should not throw to cause user's programe exit with code 1 | ||
console.error(e); | ||
}); | ||
@@ -190,3 +192,3 @@ return streamGenerateContentResultPromise; | ||
this.googleAuth = request.googleAuth; | ||
this.publisherModelEndpoint = request.publisherModelEndpoint; | ||
this.resourcePath = request.resourcePath; | ||
this.historyInternal = (_a = request.history) !== null && _a !== void 0 ? _a : []; | ||
@@ -199,5 +201,4 @@ this.generationConfig = request.generationConfig; | ||
if (request.systemInstruction) { | ||
request.systemInstruction.role = util_1.constants.SYSTEM_ROLE; | ||
this.systemInstruction = (0, util_1.formulateSystemInstructionIntoContent)(request.systemInstruction); | ||
} | ||
this.systemInstruction = request.systemInstruction; | ||
} | ||
@@ -210,3 +211,3 @@ /** | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new errors_1.GoogleAuthError(util_1.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
throw new errors_1.GoogleAuthError(util_2.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
}); | ||
@@ -243,3 +244,3 @@ return tokenPromise; | ||
}; | ||
const generateContentResult = await (0, generate_content_1.generateContent)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
const generateContentResult = await (0, generate_content_1.generateContent)(this.location, this.resourcePath, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
throw e; | ||
@@ -300,6 +301,10 @@ }); | ||
}; | ||
const streamGenerateContentResultPromise = (0, generate_content_1.generateContentStream)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
const streamGenerateContentResultPromise = (0, generate_content_1.generateContentStream)(this.location, this.resourcePath, this.fetchToken(), generateContentrequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions).catch(e => { | ||
throw e; | ||
}); | ||
this.sendStreamPromise = this.appendHistory(streamGenerateContentResultPromise, newContent); | ||
this.sendStreamPromise = this.appendHistory(streamGenerateContentResultPromise, newContent).catch(e => { | ||
// Errors from remote endpoint will be catchable by user from streamGenerateContentResultPromise | ||
// Errors in appendHistory should not throw to cause user's programe exit with code 1 | ||
console.error(e); | ||
}); | ||
return streamGenerateContentResultPromise; | ||
@@ -335,4 +340,4 @@ } | ||
function assignRoleToPartsAndValidateSendMessageRequest(parts) { | ||
const userContent = { role: util_1.constants.USER_ROLE, parts: [] }; | ||
const functionContent = { role: util_1.constants.USER_ROLE, parts: [] }; | ||
const userContent = { role: util_2.constants.USER_ROLE, parts: [] }; | ||
const functionContent = { role: util_2.constants.USER_ROLE, parts: [] }; | ||
let hasUserContent = false; | ||
@@ -339,0 +344,0 @@ let hasFunctionContent = false; |
@@ -36,2 +36,3 @@ /** | ||
private readonly publisherModelEndpoint; | ||
private readonly resourcePath; | ||
private readonly apiEndpoint?; | ||
@@ -155,2 +156,3 @@ /** | ||
private readonly publisherModelEndpoint; | ||
private readonly resourcePath; | ||
private readonly apiEndpoint?; | ||
@@ -157,0 +159,0 @@ /** |
@@ -20,6 +20,7 @@ "use strict"; | ||
exports.GenerativeModelPreview = exports.GenerativeModel = void 0; | ||
const util_1 = require("./util"); | ||
const count_tokens_1 = require("../functions/count_tokens"); | ||
const generate_content_1 = require("../functions/generate_content"); | ||
const errors_1 = require("../types/errors"); | ||
const util_1 = require("../util"); | ||
const util_2 = require("../util"); | ||
const chat_session_1 = require("./chat_session"); | ||
@@ -49,11 +50,7 @@ /** | ||
if (getGenerativeModelParams.systemInstruction) { | ||
getGenerativeModelParams.systemInstruction.role = util_1.constants.SYSTEM_ROLE; | ||
this.systemInstruction = (0, util_1.formulateSystemInstructionIntoContent)(getGenerativeModelParams.systemInstruction); | ||
} | ||
this.systemInstruction = getGenerativeModelParams.systemInstruction; | ||
if (this.model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} | ||
else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
this.resourcePath = formulateResourcePathFromModel(this.model, this.project, this.location); | ||
// publisherModelEndpoint is deprecated | ||
this.publisherModelEndpoint = this.resourcePath; | ||
} | ||
@@ -67,3 +64,3 @@ /** | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new errors_1.GoogleAuthError(util_1.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
throw new errors_1.GoogleAuthError(util_2.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
}); | ||
@@ -93,3 +90,3 @@ return tokenPromise; | ||
const formulatedRequest = formulateSystemInstructionIntoGenerateContentRequest(request, this.systemInstruction); | ||
return (0, generate_content_1.generateContent)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
return (0, generate_content_1.generateContent)(this.location, this.resourcePath, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
} | ||
@@ -123,3 +120,3 @@ /** | ||
const formulatedRequest = formulateSystemInstructionIntoGenerateContentRequest(request, this.systemInstruction); | ||
return (0, generate_content_1.generateContentStream)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
return (0, generate_content_1.generateContentStream)(this.location, this.resourcePath, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
} | ||
@@ -145,3 +142,3 @@ /** | ||
async countTokens(request) { | ||
return (0, count_tokens_1.countTokens)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), request, this.apiEndpoint, this.requestOptions); | ||
return (0, count_tokens_1.countTokens)(this.location, this.resourcePath, this.fetchToken(), request, this.apiEndpoint, this.requestOptions); | ||
} | ||
@@ -179,3 +176,5 @@ /** | ||
publisherModelEndpoint: this.publisherModelEndpoint, | ||
resourcePath: this.resourcePath, | ||
tools: this.tools, | ||
systemInstruction: this.systemInstruction, | ||
}; | ||
@@ -220,11 +219,7 @@ if (request) { | ||
if (getGenerativeModelParams.systemInstruction) { | ||
getGenerativeModelParams.systemInstruction.role = util_1.constants.SYSTEM_ROLE; | ||
this.systemInstruction = (0, util_1.formulateSystemInstructionIntoContent)(getGenerativeModelParams.systemInstruction); | ||
} | ||
this.systemInstruction = getGenerativeModelParams.systemInstruction; | ||
if (this.model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} | ||
else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
this.resourcePath = formulateResourcePathFromModel(this.model, this.project, this.location); | ||
// publisherModelEndpoint is deprecated | ||
this.publisherModelEndpoint = this.resourcePath; | ||
} | ||
@@ -238,3 +233,3 @@ /** | ||
const tokenPromise = this.googleAuth.getAccessToken().catch(e => { | ||
throw new errors_1.GoogleAuthError(util_1.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
throw new errors_1.GoogleAuthError(util_2.constants.CREDENTIAL_ERROR_MESSAGE, e); | ||
}); | ||
@@ -263,3 +258,3 @@ return tokenPromise; | ||
const formulatedRequest = formulateSystemInstructionIntoGenerateContentRequest(request, this.systemInstruction); | ||
return (0, generate_content_1.generateContent)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
return (0, generate_content_1.generateContent)(this.location, this.resourcePath, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
} | ||
@@ -293,3 +288,3 @@ /** | ||
const formulatedRequest = formulateSystemInstructionIntoGenerateContentRequest(request, this.systemInstruction); | ||
return (0, generate_content_1.generateContentStream)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
return (0, generate_content_1.generateContentStream)(this.location, this.resourcePath, this.fetchToken(), formulatedRequest, this.apiEndpoint, this.generationConfig, this.safetySettings, this.tools, this.requestOptions); | ||
} | ||
@@ -315,3 +310,3 @@ /** | ||
async countTokens(request) { | ||
return (0, count_tokens_1.countTokens)(this.location, this.project, this.publisherModelEndpoint, this.fetchToken(), request, this.apiEndpoint, this.requestOptions); | ||
return (0, count_tokens_1.countTokens)(this.location, this.resourcePath, this.fetchToken(), request, this.apiEndpoint, this.requestOptions); | ||
} | ||
@@ -349,3 +344,5 @@ /** | ||
publisherModelEndpoint: this.publisherModelEndpoint, | ||
resourcePath: this.resourcePath, | ||
tools: this.tools, | ||
systemInstruction: this.systemInstruction, | ||
}; | ||
@@ -366,6 +363,28 @@ if (request) { | ||
exports.GenerativeModelPreview = GenerativeModelPreview; | ||
function formulateResourcePathFromModel(model, project, location) { | ||
let resourcePath; | ||
if (!model) { | ||
throw new errors_1.ClientError('model parameter must not be empty.'); | ||
} | ||
if (!model.includes('/')) { | ||
// example 'gemini-1.0-pro' | ||
resourcePath = `projects/${project}/locations/${location}/publishers/google/models/${model}`; | ||
} | ||
else if (model.startsWith('models/')) { | ||
// example 'models/gemini-1.0-pro' | ||
resourcePath = `projects/${project}/locations/${location}/publishers/google/${model}`; | ||
} | ||
else if (model.startsWith('projects/')) { | ||
// example 'projects/my-project/locations/my-location/models/my-tuned-model' | ||
resourcePath = model; | ||
} | ||
else { | ||
throw new errors_1.ClientError('model parameter must be either a Model Garden model ID or a full resource name.'); | ||
} | ||
return resourcePath; | ||
} | ||
function formulateRequestToGenerateContentRequest(request) { | ||
if (typeof request === 'string') { | ||
return { | ||
contents: [{ role: util_1.constants.USER_ROLE, parts: [{ text: request }] }], | ||
contents: [{ role: util_2.constants.USER_ROLE, parts: [{ text: request }] }], | ||
}; | ||
@@ -377,3 +396,3 @@ } | ||
if (methodRequest.systemInstruction) { | ||
methodRequest.systemInstruction.role = util_1.constants.SYSTEM_ROLE; | ||
methodRequest.systemInstruction = (0, util_1.formulateSystemInstructionIntoContent)(methodRequest.systemInstruction); | ||
return methodRequest; | ||
@@ -380,0 +399,0 @@ } |
@@ -24,5 +24,9 @@ /** | ||
project: string; | ||
/** The Google Cloud project location. */ | ||
location: string; | ||
/** | ||
* Optional. The Google Cloud project location. If not provided, SDK will | ||
* firtly try to resolve it from run time environment. If no location resolved | ||
* from run time environment, SDK will use default value `us-central1`. | ||
*/ | ||
location?: string; | ||
/** | ||
* Optional. The base Vertex AI endpoint to use for the request. If not | ||
@@ -47,6 +51,7 @@ * provided, the default regionalized endpoint (i.e. | ||
contents: Content[]; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -104,6 +109,7 @@ /** | ||
requestOptions?: RequestOptions; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -131,6 +137,7 @@ /** | ||
tools?: Tool[]; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -525,8 +532,16 @@ /** | ||
export declare interface GoogleDate { | ||
/** Year of the date. Must be from 1 to 9999, or 0 to specify a date without a year. */ | ||
/** | ||
* Year of the date. Must be from 1 to 9999, or 0 to specify a date without a | ||
* year. | ||
*/ | ||
year?: number; | ||
/** Month of the date. Must be from 1 to 12, or 0 to specify a year without a monthi and day. */ | ||
/** | ||
* Month of the date. Must be from 1 to 12, or 0 to specify a year without a | ||
* monthi and day. | ||
*/ | ||
month?: number; | ||
/** Day of the date. Must be from 1 to 31 and valid for the year and month. | ||
* or 0 to specify a year by itself or a year and month where the day isn't significant | ||
/** | ||
* Day of the date. Must be from 1 to 31 and valid for the year and month. | ||
* or 0 to specify a year by itself or a year and month where the day isn't | ||
* significant | ||
*/ | ||
@@ -553,2 +568,11 @@ day?: number; | ||
/** | ||
* Google search entry point. | ||
*/ | ||
export declare interface SearchEntryPoint { | ||
/** Optional. Web content snippet that can be embedded in a web page or an app webview. */ | ||
renderedContent?: string; | ||
/** Optional. Base64 encoded JSON representing array of tuple. */ | ||
sdkBlob?: string; | ||
} | ||
/** | ||
* A collection of grounding attributions for a piece of content. | ||
@@ -563,2 +587,4 @@ */ | ||
groundingAttributions?: GroundingAttribution[]; | ||
/** Optional. Google search entry for the following-up web searches. {@link SearchEntryPoint} */ | ||
searchEntryPoint?: SearchEntryPoint; | ||
} | ||
@@ -704,3 +730,32 @@ /** | ||
} | ||
export declare interface VertexRagStore { | ||
/** | ||
* Optional. List of corpora for retrieval. Currently only support one corpus | ||
* or multiple files from one corpus. In the future we may open up multiple | ||
* corpora support. | ||
*/ | ||
ragResources?: RagResource[]; | ||
/** Optional. Number of top k results to return from the selected corpora. */ | ||
similarityTopK?: number; | ||
/** Optional. If set this field, results with vector distance smaller than this threshold will be returned. */ | ||
vectorDistanceThreshold?: number; | ||
} | ||
/** | ||
* Config of Vertex RagStore grounding checking. | ||
*/ | ||
export declare interface RagResource { | ||
/** | ||
* Optional. Vertex RAG Store corpus resource name. | ||
* | ||
* @example | ||
* `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` | ||
*/ | ||
ragCorpus?: string; | ||
/** | ||
* Optional. Set this field to select the files under the ragCorpora for | ||
* retrieval. | ||
*/ | ||
ragFileIds?: string[]; | ||
} | ||
/** | ||
* Defines a retrieval tool that model can call to access external knowledge. | ||
@@ -723,2 +778,4 @@ */ | ||
vertexAiSearch?: VertexAISearch; | ||
/** Optional. Set to use data source powered by Vertex RAG store. */ | ||
vertexRagStore?: VertexRagStore; | ||
/** | ||
@@ -832,6 +889,7 @@ * Optional. Disable using the result from this tool in detecting grounding | ||
apiEndpoint?: string; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -850,6 +908,9 @@ /** | ||
publisherModelEndpoint: string; | ||
/** Optional. The user provided system instructions for the model. | ||
/** The resource path to use for the request. */ | ||
resourcePath: string; | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -856,0 +917,0 @@ /** |
@@ -19,6 +19,7 @@ /** | ||
export declare const STREAMING_GENERATE_CONTENT_METHOD = "streamGenerateContent"; | ||
export declare const COUNT_TOKENS_METHOD = "countTokens"; | ||
export declare const USER_ROLE = "user"; | ||
export declare const MODEL_ROLE = "model"; | ||
export declare const SYSTEM_ROLE = "system"; | ||
export declare const USER_AGENT = "model-builder/1.1.0 grpc-node/1.1.0"; | ||
export declare const USER_AGENT = "model-builder/1.2.0 grpc-node/1.2.0"; | ||
export declare const CREDENTIAL_ERROR_MESSAGE = "\nUnable to authenticate your request \nDepending on your run time environment, you can get authentication by \n- if in local instance or cloud shell: `!gcloud auth login` \n- if in Colab: \n -`from google.colab import auth` \n -`auth.authenticate_user()` \n- if in service account or other: please follow guidance in https://cloud.google.com/docs/authentication"; |
"use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.CREDENTIAL_ERROR_MESSAGE = exports.USER_AGENT = exports.SYSTEM_ROLE = exports.MODEL_ROLE = exports.USER_ROLE = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0; | ||
exports.CREDENTIAL_ERROR_MESSAGE = exports.USER_AGENT = exports.SYSTEM_ROLE = exports.MODEL_ROLE = exports.USER_ROLE = exports.COUNT_TOKENS_METHOD = exports.STREAMING_GENERATE_CONTENT_METHOD = exports.GENERATE_CONTENT_METHOD = void 0; | ||
/** | ||
@@ -22,2 +22,3 @@ * @license | ||
exports.STREAMING_GENERATE_CONTENT_METHOD = 'streamGenerateContent'; | ||
exports.COUNT_TOKENS_METHOD = 'countTokens'; | ||
exports.USER_ROLE = 'user'; | ||
@@ -27,3 +28,3 @@ exports.MODEL_ROLE = 'model'; | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_VERSION = '1.1.0'; // x-release-please-version | ||
const CLIENT_LIBRARY_VERSION = '1.2.0'; // x-release-please-version | ||
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`; | ||
@@ -30,0 +31,0 @@ exports.USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`; |
@@ -40,3 +40,3 @@ "use strict"; | ||
this.project = init.project; | ||
this.location = init.location; | ||
this.location = resolveLocation(init.location); | ||
this.googleAuth = new google_auth_library_1.GoogleAuth(opts); | ||
@@ -178,2 +178,12 @@ this.apiEndpoint = init.apiEndpoint; | ||
} | ||
function resolveLocation(locationFromInput) { | ||
if (locationFromInput) { | ||
return locationFromInput; | ||
} | ||
const inferredLocation = process.env['GOOGLE_CLOUD_REGION'] || process.env['CLOUD_ML_REGION']; | ||
if (inferredLocation) { | ||
return inferredLocation; | ||
} | ||
return 'us-central1'; | ||
} | ||
//# sourceMappingURL=vertex_ai.js.map |
@@ -22,3 +22,3 @@ "use strict"; | ||
const types_1 = require("../src/types"); | ||
const PROJECT = process.env.GCLOUD_PROJECT; | ||
const PROJECT = process.env['GCLOUD_PROJECT']; | ||
const LOCATION = 'us-central1'; | ||
@@ -79,2 +79,15 @@ const TEXT_REQUEST = { | ||
]; | ||
const TOOLS_WITH_RAG = [ | ||
{ | ||
retrieval: { | ||
vertexRagStore: { | ||
ragResources: [ | ||
{ | ||
ragCorpus: 'projects/ucaip-sample-tests/locations/us-central1/ragCorpora/6917529027641081856', | ||
}, | ||
], | ||
}, | ||
}, | ||
}, | ||
]; | ||
const WEATHER_FORECAST = 'super nice'; | ||
@@ -259,3 +272,3 @@ const FUNCTION_RESPONSE_PART = [ | ||
}); | ||
it('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
var _a; | ||
@@ -277,3 +290,3 @@ const request = { | ||
}); | ||
it('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
var _a; | ||
@@ -295,3 +308,3 @@ const request = { | ||
}); | ||
it('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
xit('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
var _a; | ||
@@ -314,3 +327,3 @@ const request = { | ||
}); | ||
it('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
xit('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
var _a; | ||
@@ -397,2 +410,22 @@ const request = { | ||
}); | ||
it('in preview should return grounding metadata when passed a VertexRagStore', async () => { | ||
var _a, _b; | ||
const request = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{ | ||
text: 'How much gain or loss did Google get in the Motorola Mobile deal in 2014?', | ||
}, | ||
], | ||
}, | ||
], | ||
tools: TOOLS_WITH_RAG, | ||
}; | ||
const result = await generativeTextModelPreview.generateContentStream(request); | ||
const response = await result.response; | ||
expect(response.candidates[0]).toBeTruthy(`sys test failure on generateContent with RAG tool, for resp ${JSON.stringify(response)}`); | ||
expect((_b = (_a = response.candidates[0]) === null || _a === void 0 ? void 0 : _a.groundingMetadata) === null || _b === void 0 ? void 0 : _b.retrievalQueries).toBeTruthy(`sys test failure on generateContent with RAG tool, empty groundingMetadata.retrievalQueries, for resp ${JSON.stringify(response)}`); | ||
}); | ||
}); | ||
@@ -477,3 +510,3 @@ describe('generateContent', () => { | ||
}); | ||
it('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
var _a; | ||
@@ -492,3 +525,3 @@ const request = { | ||
}); | ||
it('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
var _a; | ||
@@ -507,5 +540,24 @@ const request = { | ||
}); | ||
it('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
it('in preview should return grounding metadata when passed a VertexRagStore', async () => { | ||
var _a, _b; | ||
const request = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{ | ||
text: 'How much gain or loss did Google get in the Motorola Mobile deal in 2014?', | ||
}, | ||
], | ||
}, | ||
], | ||
tools: TOOLS_WITH_RAG, | ||
}; | ||
const resp = await generativeTextModelPreview.generateContent(request); | ||
expect(resp.response.candidates[0]).toBeTruthy(`sys test failure on generateContent with RAG tool, for resp ${JSON.stringify(resp)}`); | ||
expect((_b = (_a = resp.response.candidates[0]) === null || _a === void 0 ? void 0 : _a.groundingMetadata) === null || _b === void 0 ? void 0 : _b.retrievalQueries).toBeTruthy(`sys test failure on generateContent with RAG tool, empty groundingMetadata.retrievalQueries, for resp ${JSON.stringify(resp)}`); | ||
}); | ||
xit('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const request = { | ||
contents: [ | ||
{ role: 'user', parts: [{ text: 'What is the weather in Boston?' }] }, | ||
@@ -523,3 +575,3 @@ ], | ||
}); | ||
it('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
xit('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const request = { | ||
@@ -679,3 +731,3 @@ contents: [ | ||
}); | ||
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const chat = generativeTextModel.startChat({ | ||
@@ -700,3 +752,3 @@ tools: TOOLS_WITH_FUNCTION_DECLARATION, | ||
}); | ||
it('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const chat = generativeTextModelPreview.startChat({ | ||
@@ -703,0 +755,0 @@ tools: TOOLS_WITH_FUNCTION_DECLARATION, |
@@ -16,2 +16,40 @@ "use strict"; | ||
}); | ||
it('no location given, should instantiate VertexAI and VertexAIPreview', () => { | ||
const vertexaiNoLocation = new vertex_ai_1.VertexAI({ project: PROJECT }); | ||
const generativeModel = vertexaiNoLocation.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeModelPreview = vertexaiNoLocation.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
expect(vertexaiNoLocation).toBeInstanceOf(vertex_ai_1.VertexAI); | ||
expect(generativeModel).toBeInstanceOf(models_1.GenerativeModel); | ||
expect(generativeModelPreview).toBeInstanceOf(models_1.GenerativeModelPreview); | ||
}); | ||
it('location in run time env GOOGLE_CLOUD_REGION, should instantiate VertexAI and VertexAIPreview', () => { | ||
process.env['GOOGLE_CLOUD_REGION'] = 'us-central1'; | ||
const vertexaiNoLocation = new vertex_ai_1.VertexAI({ project: PROJECT }); | ||
const generativeModel = vertexaiNoLocation.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeModelPreview = vertexaiNoLocation.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
expect(vertexaiNoLocation).toBeInstanceOf(vertex_ai_1.VertexAI); | ||
expect(generativeModel).toBeInstanceOf(models_1.GenerativeModel); | ||
expect(generativeModelPreview).toBeInstanceOf(models_1.GenerativeModelPreview); | ||
}); | ||
it('location in run time env CLOUD_ML_REGION, should instantiate VertexAI and VertexAIPreview', () => { | ||
process.env['CLOUD_ML_REGION'] = 'us-central1'; | ||
const vertexaiNoLocation = new vertex_ai_1.VertexAI({ project: PROJECT }); | ||
const generativeModel = vertexaiNoLocation.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeModelPreview = vertexaiNoLocation.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
expect(vertexaiNoLocation).toBeInstanceOf(vertex_ai_1.VertexAI); | ||
expect(generativeModel).toBeInstanceOf(models_1.GenerativeModel); | ||
expect(generativeModelPreview).toBeInstanceOf(models_1.GenerativeModelPreview); | ||
}); | ||
it('given undefined google auth options, should be instantiated', () => { | ||
@@ -18,0 +56,0 @@ expect(vertexai).toBeInstanceOf(vertex_ai_1.VertexAI); |
# Changelog | ||
## [1.2.0](https://github.com/googleapis/nodejs-vertexai/compare/v1.1.0...v1.2.0) (2024-05-22) | ||
### Features | ||
* allow users to pass string as system instruction ([a824162](https://github.com/googleapis/nodejs-vertexai/commit/a824162a29b8c6ba3ccae46c492dd28e9c2baf9c)) | ||
* enable inference request to tuned model. ([de9c4c2](https://github.com/googleapis/nodejs-vertexai/commit/de9c4c2f8c63a298bd28ab69dae9b6a5d72c20d7)) | ||
* infer location if user doesn't specifies it. ([b8d4af1](https://github.com/googleapis/nodejs-vertexai/commit/b8d4af1bb990e95093f446c808194bfc4fe53287)) | ||
* support RAG in public preview ([5ade755](https://github.com/googleapis/nodejs-vertexai/commit/5ade7551fe0dbab54bc56f251c32bf3b7802c2c5)) | ||
* update grounding metadata ([d3c0a64](https://github.com/googleapis/nodejs-vertexai/commit/d3c0a647248be6be49b6b93a18aad79c10bae6c4)) | ||
### Bug Fixes | ||
* log instead of throw appendHistory errors to avoid unhandled rejection ([2ec9e7d](https://github.com/googleapis/nodejs-vertexai/commit/2ec9e7d5519af438eb03b9f21f90b86f2575ac47)) | ||
## [1.1.0](https://github.com/googleapis/nodejs-vertexai/compare/v1.0.0...v1.1.0) (2024-04-13) | ||
@@ -4,0 +20,0 @@ |
{ | ||
"name": "@google-cloud/vertexai", | ||
"description": "Vertex Generative AI client for Node.js", | ||
"version": "1.1.0", | ||
"version": "1.2.0", | ||
"license": "Apache-2.0", | ||
@@ -6,0 +6,0 @@ "author": "Google LLC", |
@@ -18,2 +18,3 @@ [![NPM Downloads](https://img.shields.io/npm/dm/%40google-cloud%2Fvertexai)](https://www.npmjs.com/package/@google-cloud/vertexai) | ||
1. Make sure your node.js version is 18 or above. | ||
1. [Select](https://console.cloud.google.com/project) or [create](https://cloud.google.com/resource-manager/docs/creating-managing-projects#creating_a_project) a Google Cloud project. | ||
@@ -20,0 +21,0 @@ 1. [Enable billing for your project](https://cloud.google.com/billing/docs/how-to/modify-project). |
@@ -24,2 +24,3 @@ /** | ||
import {GoogleGenerativeAIError} from '../types/errors'; | ||
import * as constants from '../util/constants'; | ||
import { | ||
@@ -38,4 +39,3 @@ throwErrorIfNotOK, | ||
location: string, | ||
project: string, | ||
publisherModelEndpoint: string, | ||
resourcePath: string, | ||
token: Promise<string | null | undefined>, | ||
@@ -48,5 +48,4 @@ request: CountTokensRequest, | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourceMethod: 'countTokens', | ||
resourcePath: resourcePath, | ||
resourceMethod: constants.COUNT_TOKENS_METHOD, | ||
token: await token, | ||
@@ -53,0 +52,0 @@ data: request, |
@@ -46,2 +46,4 @@ /** | ||
validateGenerationConfig, | ||
hasVertexRagStore, | ||
getApiVersion, | ||
} from './pre_fetch_processing'; | ||
@@ -51,4 +53,3 @@ | ||
location: string, | ||
project: string, | ||
publisherModelEndpoint: string, | ||
resourcePath: string, | ||
token: Promise<string | null | undefined>, | ||
@@ -81,9 +82,9 @@ request: GenerateContentRequest | string, | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourcePath, | ||
resourceMethod: constants.GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
requestOptions: requestOptions, | ||
apiEndpoint, | ||
requestOptions, | ||
apiVersion: getApiVersion(request), | ||
}).catch(e => { | ||
@@ -107,4 +108,3 @@ throw new GoogleGenerativeAIError('exception posting request to model', e); | ||
location: string, | ||
project: string, | ||
publisherModelEndpoint: string, | ||
resourcePath: string, | ||
token: Promise<string | null | undefined>, | ||
@@ -136,9 +136,9 @@ request: GenerateContentRequest | string, | ||
region: location, | ||
project: project, | ||
resourcePath: publisherModelEndpoint, | ||
resourcePath, | ||
resourceMethod: constants.STREAMING_GENERATE_CONTENT_METHOD, | ||
token: await token, | ||
data: generateContentRequest, | ||
apiEndpoint: apiEndpoint, | ||
requestOptions: requestOptions, | ||
apiEndpoint, | ||
requestOptions, | ||
apiVersion: getApiVersion(request), | ||
}).catch(e => { | ||
@@ -145,0 +145,0 @@ throw new GoogleGenerativeAIError('exception posting request', e); |
@@ -26,3 +26,2 @@ /** | ||
GroundingMetadata, | ||
Part, | ||
StreamGenerateContentResult, | ||
@@ -301,2 +300,3 @@ } from '../types/content'; | ||
groundingAttributions: [], | ||
retrievalQueries: [], | ||
}; | ||
@@ -319,2 +319,12 @@ const groundingMetadataAggregated: GroundingMetadata = | ||
} | ||
if (groundingMetadataChunk.retrievalQueries) { | ||
groundingMetadataAggregated.retrievalQueries = | ||
groundingMetadataAggregated.retrievalQueries!.concat( | ||
groundingMetadataChunk.retrievalQueries | ||
); | ||
} | ||
if (groundingMetadataChunk.searchEntryPoint) { | ||
groundingMetadataAggregated.searchEntryPoint = | ||
groundingMetadataChunk.searchEntryPoint; | ||
} | ||
return groundingMetadataAggregated; | ||
@@ -321,0 +331,0 @@ } |
@@ -41,3 +41,2 @@ /** | ||
region, | ||
project, | ||
resourcePath, | ||
@@ -52,3 +51,2 @@ resourceMethod, | ||
region: string; | ||
project: string; | ||
resourcePath: string; | ||
@@ -64,3 +62,3 @@ resourceMethod: string; | ||
let vertexEndpoint = `https://${vertexBaseEndpoint}/${apiVersion}/projects/${project}/locations/${region}/${resourcePath}:${resourceMethod}`; | ||
let vertexEndpoint = `https://${vertexBaseEndpoint}/${apiVersion}/${resourcePath}:${resourceMethod}`; | ||
@@ -67,0 +65,0 @@ // Use server sent events for streamGenerateContent |
@@ -21,4 +21,7 @@ /** | ||
GenerationConfig, | ||
RetrievalTool, | ||
SafetySetting, | ||
Tool, | ||
} from '../types/content'; | ||
import {ClientError} from '../types/errors'; | ||
import * as constants from '../util/constants'; | ||
@@ -59,2 +62,8 @@ | ||
} | ||
if (hasVertexAISearch(request) && hasVertexRagStore(request)) { | ||
throw new ClientError( | ||
'Found both vertexAiSearch and vertexRagStore field are set in tool. Either set vertexAiSearch or vertexRagStore.' | ||
); | ||
} | ||
} | ||
@@ -72,1 +81,29 @@ | ||
} | ||
export function getApiVersion( | ||
request: GenerateContentRequest | ||
): 'v1' | 'v1beta1' { | ||
return hasVertexRagStore(request) ? 'v1beta1' : 'v1'; | ||
} | ||
export function hasVertexRagStore(request: GenerateContentRequest): boolean { | ||
for (const tool of request?.tools ?? []) { | ||
const retrieval = (tool as RetrievalTool).retrieval; | ||
if (!retrieval) continue; | ||
if (retrieval.vertexRagStore) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
export function hasVertexAISearch(request: GenerateContentRequest): boolean { | ||
for (const tool of request?.tools ?? []) { | ||
const retrieval = (tool as RetrievalTool).retrieval; | ||
if (!retrieval) continue; | ||
if (retrieval.vertexAiSearch) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} |
@@ -40,5 +40,4 @@ /** | ||
const TEST_PROJECT = 'test-project'; | ||
const TEST_LOCATION = 'test-location'; | ||
const TEST_PUBLISHER_MODEL_ENDPOINT = 'test-publisher-model-endpoint'; | ||
const TEST_RESOURCE_PATH = 'test-resource-path'; | ||
const TEST_TOKEN = 'testtoken'; | ||
@@ -52,2 +51,9 @@ const TEST_TOKEN_PROMISE = Promise.resolve(TEST_TOKEN); | ||
const CONTENTS = [ | ||
{ | ||
role: 'user', | ||
parts: [{text: 'What is the weater like in Boston?'}], | ||
}, | ||
]; | ||
const TEST_USER_CHAT_MESSAGE_WITH_GCS_FILE = [ | ||
@@ -214,2 +220,8 @@ { | ||
const TEST_TOOLS_WITH_RAG: Tool[] = [ | ||
{ | ||
retrieval: {vertexRagStore: {ragResources: [{ragCorpus: 'ragCorpus'}]}}, | ||
}, | ||
]; | ||
const fetchResponseObj = { | ||
@@ -255,4 +267,3 @@ status: 200, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -275,4 +286,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -304,4 +314,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -331,4 +340,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -361,4 +369,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -385,4 +392,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -401,4 +407,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -421,4 +426,3 @@ TEST_CHAT_MESSAGE_TEXT, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -438,4 +442,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -460,4 +463,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -476,4 +478,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -497,4 +498,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -519,4 +519,3 @@ reqWithEmptyConfigs, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -539,4 +538,3 @@ reqWithEmptyConfigs, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -555,5 +553,3 @@ req, | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weater like in Boston?'}]}, | ||
], | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
@@ -573,4 +569,3 @@ }; | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -597,8 +592,3 @@ req, | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{text: 'What is the weater like in Boston?'}], | ||
}, | ||
], | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
@@ -618,4 +608,3 @@ }; | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -635,8 +624,3 @@ req, | ||
const req: GenerateContentRequest = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [{text: 'What is the weater like in Boston?'}], | ||
}, | ||
], | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_FUNCTION_DECLARATION, | ||
@@ -648,4 +632,3 @@ }; | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -657,2 +640,35 @@ req, | ||
}); | ||
it('should use v1 apiVersion', async () => { | ||
const request: GenerateContentRequest = { | ||
contents: CONTENTS, | ||
}; | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
await generateContent( | ||
TEST_LOCATION, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
request, | ||
TEST_API_ENDPOINT | ||
); | ||
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0]; | ||
expect(vertexEndpoint).toContain('/v1/'); | ||
}); | ||
it('should use v1beta1 apiVersion when set RAG in tools', async () => { | ||
const request: GenerateContentRequest = { | ||
contents: CONTENTS, | ||
tools: TEST_TOOLS_WITH_RAG, | ||
}; | ||
fetchSpy.and.resolveTo(buildFetchResponse(TEST_MODEL_RESPONSE)); | ||
await generateContent( | ||
TEST_LOCATION, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
request, | ||
TEST_API_ENDPOINT | ||
); | ||
const vertexEndpoint = fetchSpy.calls.allArgs()[0][0]; | ||
expect(vertexEndpoint).toContain('/v1beta1/'); | ||
}); | ||
}); | ||
@@ -688,4 +704,3 @@ | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -714,4 +729,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -733,4 +747,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -755,4 +768,3 @@ TEST_API_ENDPOINT, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -777,4 +789,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -801,4 +812,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -838,4 +848,3 @@ req, | ||
TEST_LOCATION, | ||
TEST_PROJECT, | ||
TEST_PUBLISHER_MODEL_ENDPOINT, | ||
TEST_RESOURCE_PATH, | ||
TEST_TOKEN_PROMISE, | ||
@@ -842,0 +851,0 @@ req, |
@@ -43,4 +43,3 @@ /** | ||
const LOCATION = 'location'; | ||
const PROJECT = 'project'; | ||
const PUBLISHER_MODEL_ENDPOINT = 'publisher_model_endpoint'; | ||
const RESOURCE_PATH = 'RESOURCE_PATH'; | ||
const TOKEN = Promise.resolve('token'); | ||
@@ -90,4 +89,3 @@ const GENERATE_CONTENT_REQUEST = 'generate_content_request'; | ||
LOCATION, | ||
PROJECT, | ||
PUBLISHER_MODEL_ENDPOINT, | ||
RESOURCE_PATH, | ||
TOKEN, | ||
@@ -110,4 +108,3 @@ GENERATE_CONTENT_REQUEST | ||
LOCATION, | ||
PROJECT, | ||
PUBLISHER_MODEL_ENDPOINT, | ||
RESOURCE_PATH, | ||
TOKEN, | ||
@@ -127,4 +124,3 @@ GENERATE_CONTENT_REQUEST | ||
LOCATION, | ||
PROJECT, | ||
PUBLISHER_MODEL_ENDPOINT, | ||
RESOURCE_PATH, | ||
TOKEN, | ||
@@ -149,4 +145,3 @@ GENERATE_CONTENT_REQUEST | ||
LOCATION, | ||
PROJECT, | ||
PUBLISHER_MODEL_ENDPOINT, | ||
RESOURCE_PATH, | ||
TOKEN, | ||
@@ -172,4 +167,3 @@ GENERATE_CONTENT_REQUEST | ||
LOCATION, | ||
PROJECT, | ||
PUBLISHER_MODEL_ENDPOINT, | ||
RESOURCE_PATH, | ||
TOKEN, | ||
@@ -176,0 +170,0 @@ COUNT_TOKEN_REQUEST |
@@ -6,3 +6,2 @@ import {GenerateContentRequest, RequestOptions} from '../../types'; | ||
const REGION = 'us-central1'; | ||
const PROJECT = 'project-id'; | ||
const RESOURCE_PATH = 'resource-path'; | ||
@@ -28,3 +27,2 @@ const RESOURCE_METHOD = 'resource-method'; | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -49,3 +47,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -67,3 +64,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -88,3 +84,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -109,3 +104,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -127,3 +121,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -150,3 +143,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -179,3 +171,2 @@ resourceMethod: RESOURCE_METHOD, | ||
region: REGION, | ||
project: PROJECT, | ||
resourcePath: RESOURCE_PATH, | ||
@@ -182,0 +173,0 @@ resourceMethod: RESOURCE_METHOD, |
@@ -326,2 +326,6 @@ /** | ||
], | ||
searchEntryPoint: { | ||
renderedContent: | ||
'rendered content for later chunk for first candidate', | ||
}, | ||
}, | ||
@@ -348,2 +352,6 @@ }, | ||
], | ||
searchEntryPoint: { | ||
renderedContent: | ||
'rendered content for later chunk for second candidate', | ||
}, | ||
}, | ||
@@ -435,2 +443,7 @@ }, | ||
], | ||
retrievalQueries: [], | ||
searchEntryPoint: { | ||
renderedContent: | ||
'rendered content for later chunk for first candidate', | ||
}, | ||
}, | ||
@@ -506,2 +519,7 @@ finishReason: 'STOP', | ||
], | ||
retrievalQueries: [], | ||
searchEntryPoint: { | ||
renderedContent: | ||
'rendered content for later chunk for second candidate', | ||
}, | ||
}, | ||
@@ -508,0 +526,0 @@ finishReason: 'STOP', |
@@ -21,2 +21,3 @@ /** | ||
import {formulateSystemInstructionIntoContent} from './util'; | ||
import { | ||
@@ -38,7 +39,3 @@ generateContent, | ||
} from '../types/content'; | ||
import { | ||
ClientError, | ||
GoogleAuthError, | ||
GoogleGenerativeAIError, | ||
} from '../types/errors'; | ||
import {ClientError, GoogleAuthError} from '../types/errors'; | ||
import {constants} from '../util'; | ||
@@ -59,3 +56,3 @@ | ||
private sendStreamPromise: Promise<void> = Promise.resolve(); | ||
private readonly publisherModelEndpoint: string; | ||
private readonly resourcePath: string; | ||
private readonly googleAuth: GoogleAuth; | ||
@@ -84,3 +81,3 @@ protected readonly requestOptions?: RequestOptions; | ||
this.googleAuth = request.googleAuth; | ||
this.publisherModelEndpoint = request.publisherModelEndpoint; | ||
this.resourcePath = request.resourcePath; | ||
this.historyInternal = request.history ?? []; | ||
@@ -93,5 +90,6 @@ this.generationConfig = request.generationConfig; | ||
if (request.systemInstruction) { | ||
request.systemInstruction.role = constants.SYSTEM_ROLE; | ||
this.systemInstruction = formulateSystemInstructionIntoContent( | ||
request.systemInstruction | ||
); | ||
} | ||
this.systemInstruction = request.systemInstruction; | ||
} | ||
@@ -145,4 +143,3 @@ | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -230,4 +227,3 @@ generateContentrequest, | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -248,3 +244,5 @@ generateContentrequest, | ||
).catch(e => { | ||
throw new GoogleGenerativeAIError('exception appending chat history', e); | ||
// Errors from remote endpoint will be catchable by user from streamGenerateContentResultPromise | ||
// Errors in appendHistory should not throw to cause user's programe exit with code 1 | ||
console.error(e); | ||
}); | ||
@@ -268,3 +266,3 @@ return streamGenerateContentResultPromise; | ||
private sendStreamPromise: Promise<void> = Promise.resolve(); | ||
private readonly publisherModelEndpoint: string; | ||
private readonly resourcePath: string; | ||
private readonly googleAuth: GoogleAuth; | ||
@@ -293,3 +291,3 @@ protected readonly requestOptions?: RequestOptions; | ||
this.googleAuth = request.googleAuth; | ||
this.publisherModelEndpoint = request.publisherModelEndpoint; | ||
this.resourcePath = request.resourcePath; | ||
this.historyInternal = request.history ?? []; | ||
@@ -302,5 +300,6 @@ this.generationConfig = request.generationConfig; | ||
if (request.systemInstruction) { | ||
request.systemInstruction.role = constants.SYSTEM_ROLE; | ||
this.systemInstruction = formulateSystemInstructionIntoContent( | ||
request.systemInstruction | ||
); | ||
} | ||
this.systemInstruction = request.systemInstruction; | ||
} | ||
@@ -353,4 +352,3 @@ | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -439,4 +437,3 @@ generateContentrequest, | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -456,3 +453,7 @@ generateContentrequest, | ||
newContent | ||
); | ||
).catch(e => { | ||
// Errors from remote endpoint will be catchable by user from streamGenerateContentResultPromise | ||
// Errors in appendHistory should not throw to cause user's programe exit with code 1 | ||
console.error(e); | ||
}); | ||
return streamGenerateContentResultPromise; | ||
@@ -459,0 +460,0 @@ } |
@@ -21,2 +21,3 @@ /** | ||
import {formulateSystemInstructionIntoContent} from './util'; | ||
import {countTokens} from '../functions/count_tokens'; | ||
@@ -42,3 +43,3 @@ import { | ||
} from '../types/content'; | ||
import {GoogleAuthError} from '../types/errors'; | ||
import {ClientError, GoogleAuthError} from '../types/errors'; | ||
import {constants} from '../util'; | ||
@@ -65,2 +66,3 @@ | ||
private readonly publisherModelEndpoint: string; | ||
private readonly resourcePath: string; | ||
private readonly apiEndpoint?: string; | ||
@@ -83,10 +85,13 @@ | ||
if (getGenerativeModelParams.systemInstruction) { | ||
getGenerativeModelParams.systemInstruction.role = constants.SYSTEM_ROLE; | ||
this.systemInstruction = formulateSystemInstructionIntoContent( | ||
getGenerativeModelParams.systemInstruction | ||
); | ||
} | ||
this.systemInstruction = getGenerativeModelParams.systemInstruction; | ||
if (this.model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
this.resourcePath = formulateResourcePathFromModel( | ||
this.model, | ||
this.project, | ||
this.location | ||
); | ||
// publisherModelEndpoint is deprecated | ||
this.publisherModelEndpoint = this.resourcePath; | ||
} | ||
@@ -135,4 +140,3 @@ | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -183,4 +187,3 @@ formulatedRequest, | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -217,4 +220,3 @@ formulatedRequest, | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -257,3 +259,5 @@ request, | ||
publisherModelEndpoint: this.publisherModelEndpoint, | ||
resourcePath: this.resourcePath, | ||
tools: this.tools, | ||
systemInstruction: this.systemInstruction, | ||
}; | ||
@@ -293,2 +297,3 @@ | ||
private readonly publisherModelEndpoint: string; | ||
private readonly resourcePath: string; | ||
private readonly apiEndpoint?: string; | ||
@@ -311,10 +316,13 @@ | ||
if (getGenerativeModelParams.systemInstruction) { | ||
getGenerativeModelParams.systemInstruction.role = constants.SYSTEM_ROLE; | ||
this.systemInstruction = formulateSystemInstructionIntoContent( | ||
getGenerativeModelParams.systemInstruction | ||
); | ||
} | ||
this.systemInstruction = getGenerativeModelParams.systemInstruction; | ||
if (this.model.startsWith('models/')) { | ||
this.publisherModelEndpoint = `publishers/google/${this.model}`; | ||
} else { | ||
this.publisherModelEndpoint = `publishers/google/models/${this.model}`; | ||
} | ||
this.resourcePath = formulateResourcePathFromModel( | ||
this.model, | ||
this.project, | ||
this.location | ||
); | ||
// publisherModelEndpoint is deprecated | ||
this.publisherModelEndpoint = this.resourcePath; | ||
} | ||
@@ -362,4 +370,3 @@ | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -410,4 +417,3 @@ formulatedRequest, | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -444,4 +450,3 @@ formulatedRequest, | ||
this.location, | ||
this.project, | ||
this.publisherModelEndpoint, | ||
this.resourcePath, | ||
this.fetchToken(), | ||
@@ -484,3 +489,5 @@ request, | ||
publisherModelEndpoint: this.publisherModelEndpoint, | ||
resourcePath: this.resourcePath, | ||
tools: this.tools, | ||
systemInstruction: this.systemInstruction, | ||
}; | ||
@@ -502,2 +509,29 @@ | ||
function formulateResourcePathFromModel( | ||
model: string, | ||
project: string, | ||
location: string | ||
): string { | ||
let resourcePath: string; | ||
if (!model) { | ||
throw new ClientError('model parameter must not be empty.'); | ||
} | ||
if (!model.includes('/')) { | ||
// example 'gemini-1.0-pro' | ||
resourcePath = `projects/${project}/locations/${location}/publishers/google/models/${model}`; | ||
} else if (model.startsWith('models/')) { | ||
// example 'models/gemini-1.0-pro' | ||
resourcePath = `projects/${project}/locations/${location}/publishers/google/${model}`; | ||
} else if (model.startsWith('projects/')) { | ||
// example 'projects/my-project/locations/my-location/models/my-tuned-model' | ||
resourcePath = model; | ||
} else { | ||
throw new ClientError( | ||
'model parameter must be either a Model Garden model ID or a full resource name.' | ||
); | ||
} | ||
return resourcePath; | ||
} | ||
function formulateRequestToGenerateContentRequest( | ||
@@ -519,3 +553,5 @@ request: GenerateContentRequest | string | ||
if (methodRequest.systemInstruction) { | ||
methodRequest.systemInstruction.role = constants.SYSTEM_ROLE; | ||
methodRequest.systemInstruction = formulateSystemInstructionIntoContent( | ||
methodRequest.systemInstruction | ||
); | ||
return methodRequest; | ||
@@ -522,0 +558,0 @@ } |
@@ -27,5 +27,9 @@ /** | ||
project: string; | ||
/** The Google Cloud project location. */ | ||
location: string; | ||
/** | ||
* Optional. The Google Cloud project location. If not provided, SDK will | ||
* firtly try to resolve it from run time environment. If no location resolved | ||
* from run time environment, SDK will use default value `us-central1`. | ||
*/ | ||
location?: string; | ||
/** | ||
* Optional. The base Vertex AI endpoint to use for the request. If not | ||
@@ -51,6 +55,7 @@ * provided, the default regionalized endpoint (i.e. | ||
contents: Content[]; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -111,6 +116,7 @@ | ||
requestOptions?: RequestOptions; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -140,6 +146,7 @@ | ||
tools?: Tool[]; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -566,8 +573,16 @@ | ||
export declare interface GoogleDate { | ||
/** Year of the date. Must be from 1 to 9999, or 0 to specify a date without a year. */ | ||
/** | ||
* Year of the date. Must be from 1 to 9999, or 0 to specify a date without a | ||
* year. | ||
*/ | ||
year?: number; | ||
/** Month of the date. Must be from 1 to 12, or 0 to specify a year without a monthi and day. */ | ||
/** | ||
* Month of the date. Must be from 1 to 12, or 0 to specify a year without a | ||
* monthi and day. | ||
*/ | ||
month?: number; | ||
/** Day of the date. Must be from 1 to 31 and valid for the year and month. | ||
* or 0 to specify a year by itself or a year and month where the day isn't significant | ||
/** | ||
* Day of the date. Must be from 1 to 31 and valid for the year and month. | ||
* or 0 to specify a year by itself or a year and month where the day isn't | ||
* significant | ||
*/ | ||
@@ -596,2 +611,12 @@ day?: number; | ||
/** | ||
* Google search entry point. | ||
*/ | ||
export declare interface SearchEntryPoint { | ||
/** Optional. Web content snippet that can be embedded in a web page or an app webview. */ | ||
renderedContent?: string; | ||
/** Optional. Base64 encoded JSON representing array of tuple. */ | ||
sdkBlob?: string; | ||
} | ||
/** | ||
* A collection of grounding attributions for a piece of content. | ||
@@ -606,2 +631,4 @@ */ | ||
groundingAttributions?: GroundingAttribution[]; | ||
/** Optional. Google search entry for the following-up web searches. {@link SearchEntryPoint} */ | ||
searchEntryPoint?: SearchEntryPoint; | ||
} | ||
@@ -757,3 +784,37 @@ | ||
export declare interface VertexRagStore { | ||
/** | ||
* Optional. List of corpora for retrieval. Currently only support one corpus | ||
* or multiple files from one corpus. In the future we may open up multiple | ||
* corpora support. | ||
*/ | ||
ragResources?: RagResource[]; | ||
/** Optional. Number of top k results to return from the selected corpora. */ | ||
similarityTopK?: number; | ||
/** Optional. If set this field, results with vector distance smaller than this threshold will be returned. */ | ||
vectorDistanceThreshold?: number; | ||
} | ||
/** | ||
* Config of Vertex RagStore grounding checking. | ||
*/ | ||
export declare interface RagResource { | ||
/** | ||
* Optional. Vertex RAG Store corpus resource name. | ||
* | ||
* @example | ||
* `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` | ||
*/ | ||
ragCorpus?: string; | ||
/** | ||
* Optional. Set this field to select the files under the ragCorpora for | ||
* retrieval. | ||
*/ | ||
ragFileIds?: string[]; | ||
} | ||
/** | ||
* Defines a retrieval tool that model can call to access external knowledge. | ||
@@ -781,2 +842,6 @@ */ | ||
vertexAiSearch?: VertexAISearch; | ||
/** Optional. Set to use data source powered by Vertex RAG store. */ | ||
vertexRagStore?: VertexRagStore; | ||
/** | ||
@@ -892,6 +957,7 @@ * Optional. Disable using the result from this tool in detecting grounding | ||
apiEndpoint?: string; | ||
/** Optional. The user provided system instructions for the model. | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -911,6 +977,9 @@ | ||
publisherModelEndpoint: string; | ||
/** Optional. The user provided system instructions for the model. | ||
/** The resource path to use for the request. */ | ||
resourcePath: string; | ||
/** | ||
* Optional. The user provided system instructions for the model. | ||
* Note: only text should be used in parts of {@link Content} | ||
*/ | ||
systemInstruction?: Content; | ||
systemInstruction?: string | Content; | ||
} | ||
@@ -917,0 +986,0 @@ |
@@ -19,2 +19,3 @@ /** | ||
export const STREAMING_GENERATE_CONTENT_METHOD = 'streamGenerateContent'; | ||
export const COUNT_TOKENS_METHOD = 'countTokens'; | ||
export const USER_ROLE = 'user'; | ||
@@ -24,3 +25,3 @@ export const MODEL_ROLE = 'model'; | ||
const USER_AGENT_PRODUCT = 'model-builder'; | ||
const CLIENT_LIBRARY_VERSION = '1.1.0'; // x-release-please-version | ||
const CLIENT_LIBRARY_VERSION = '1.2.0'; // x-release-please-version | ||
const CLIENT_LIBRARY_LANGUAGE = `grpc-node/${CLIENT_LIBRARY_VERSION}`; | ||
@@ -27,0 +28,0 @@ export const USER_AGENT = `${USER_AGENT_PRODUCT}/${CLIENT_LIBRARY_VERSION} ${CLIENT_LIBRARY_LANGUAGE}`; |
@@ -55,3 +55,3 @@ /** | ||
this.project = init.project; | ||
this.location = init.location; | ||
this.location = resolveLocation(init.location); | ||
this.googleAuth = new GoogleAuth(opts); | ||
@@ -226,1 +226,13 @@ this.apiEndpoint = init.apiEndpoint; | ||
} | ||
function resolveLocation(locationFromInput?: string): string { | ||
if (locationFromInput) { | ||
return locationFromInput; | ||
} | ||
const inferredLocation = | ||
process.env['GOOGLE_CLOUD_REGION'] || process.env['CLOUD_ML_REGION']; | ||
if (inferredLocation) { | ||
return inferredLocation; | ||
} | ||
return 'us-central1'; | ||
} |
@@ -30,3 +30,3 @@ /** | ||
const PROJECT = process.env.GCLOUD_PROJECT; | ||
const PROJECT = process.env['GCLOUD_PROJECT']; | ||
const LOCATION = 'us-central1'; | ||
@@ -95,2 +95,17 @@ const TEXT_REQUEST = { | ||
const TOOLS_WITH_RAG = [ | ||
{ | ||
retrieval: { | ||
vertexRagStore: { | ||
ragResources: [ | ||
{ | ||
ragCorpus: | ||
'projects/ucaip-sample-tests/locations/us-central1/ragCorpora/6917529027641081856', | ||
}, | ||
], | ||
}, | ||
}, | ||
}, | ||
]; | ||
const WEATHER_FORECAST = 'super nice'; | ||
@@ -403,3 +418,3 @@ const FUNCTION_RESPONSE_PART = [ | ||
it('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const request = { | ||
@@ -432,3 +447,3 @@ contents: [ | ||
}); | ||
it('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const request = { | ||
@@ -461,3 +476,3 @@ contents: [ | ||
}); | ||
it('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
xit('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const request = { | ||
@@ -488,3 +503,3 @@ contents: [ | ||
}); | ||
it('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
xit('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const request = { | ||
@@ -603,2 +618,33 @@ contents: [ | ||
}); | ||
it('in preview should return grounding metadata when passed a VertexRagStore', async () => { | ||
const request = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{ | ||
text: 'How much gain or loss did Google get in the Motorola Mobile deal in 2014?', | ||
}, | ||
], | ||
}, | ||
], | ||
tools: TOOLS_WITH_RAG, | ||
}; | ||
const result = | ||
await generativeTextModelPreview.generateContentStream(request); | ||
const response = await result.response; | ||
expect(response.candidates![0]).toBeTruthy( | ||
`sys test failure on generateContent with RAG tool, for resp ${JSON.stringify( | ||
response | ||
)}` | ||
); | ||
expect( | ||
response.candidates![0]?.groundingMetadata?.retrievalQueries | ||
).toBeTruthy( | ||
`sys test failure on generateContent with RAG tool, empty groundingMetadata.retrievalQueries, for resp ${JSON.stringify( | ||
response | ||
)}` | ||
); | ||
}); | ||
}); | ||
@@ -719,3 +765,3 @@ | ||
}); | ||
it('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const request = { | ||
@@ -745,3 +791,3 @@ contents: [ | ||
}); | ||
it('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('in preview should return a text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const request = { | ||
@@ -770,5 +816,33 @@ contents: [ | ||
}); | ||
it('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
it('in preview should return grounding metadata when passed a VertexRagStore', async () => { | ||
const request = { | ||
contents: [ | ||
{ | ||
role: 'user', | ||
parts: [ | ||
{ | ||
text: 'How much gain or loss did Google get in the Motorola Mobile deal in 2014?', | ||
}, | ||
], | ||
}, | ||
], | ||
tools: TOOLS_WITH_RAG, | ||
}; | ||
const resp = await generativeTextModelPreview.generateContent(request); | ||
expect(resp.response.candidates![0]).toBeTruthy( | ||
`sys test failure on generateContent with RAG tool, for resp ${JSON.stringify( | ||
resp | ||
)}` | ||
); | ||
expect( | ||
resp.response.candidates![0]?.groundingMetadata?.retrievalQueries | ||
).toBeTruthy( | ||
`sys test failure on generateContent with RAG tool, empty groundingMetadata.retrievalQueries, for resp ${JSON.stringify( | ||
resp | ||
)}` | ||
); | ||
}); | ||
xit('should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const request = { | ||
contents: [ | ||
{role: 'user', parts: [{text: 'What is the weather in Boston?'}]}, | ||
@@ -795,3 +869,3 @@ ], | ||
}); | ||
it('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
xit('in preview should return a FunctionCall when passed a FunctionDeclaration', async () => { | ||
const request = { | ||
@@ -1023,3 +1097,3 @@ contents: [ | ||
it('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const chat = generativeTextModel.startChat({ | ||
@@ -1072,3 +1146,3 @@ tools: TOOLS_WITH_FUNCTION_DECLARATION, | ||
}); | ||
it('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
xit('in preview should return a FunctionCall or text when passed a FunctionDeclaration or FunctionResponse', async () => { | ||
const chat = generativeTextModelPreview.startChat({ | ||
@@ -1075,0 +1149,0 @@ tools: TOOLS_WITH_FUNCTION_DECLARATION, |
@@ -17,2 +17,46 @@ import {VertexAI} from '../src/vertex_ai'; | ||
it('no location given, should instantiate VertexAI and VertexAIPreview', () => { | ||
const vertexaiNoLocation = new VertexAI({project: PROJECT}); | ||
const generativeModel = vertexaiNoLocation.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeModelPreview = | ||
vertexaiNoLocation.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
expect(vertexaiNoLocation).toBeInstanceOf(VertexAI); | ||
expect(generativeModel).toBeInstanceOf(GenerativeModel); | ||
expect(generativeModelPreview).toBeInstanceOf(GenerativeModelPreview); | ||
}); | ||
it('location in run time env GOOGLE_CLOUD_REGION, should instantiate VertexAI and VertexAIPreview', () => { | ||
process.env['GOOGLE_CLOUD_REGION'] = 'us-central1'; | ||
const vertexaiNoLocation = new VertexAI({project: PROJECT}); | ||
const generativeModel = vertexaiNoLocation.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeModelPreview = | ||
vertexaiNoLocation.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
expect(vertexaiNoLocation).toBeInstanceOf(VertexAI); | ||
expect(generativeModel).toBeInstanceOf(GenerativeModel); | ||
expect(generativeModelPreview).toBeInstanceOf(GenerativeModelPreview); | ||
}); | ||
it('location in run time env CLOUD_ML_REGION, should instantiate VertexAI and VertexAIPreview', () => { | ||
process.env['CLOUD_ML_REGION'] = 'us-central1'; | ||
const vertexaiNoLocation = new VertexAI({project: PROJECT}); | ||
const generativeModel = vertexaiNoLocation.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
const generativeModelPreview = | ||
vertexaiNoLocation.preview.getGenerativeModel({ | ||
model: 'gemini-pro', | ||
}); | ||
expect(vertexaiNoLocation).toBeInstanceOf(VertexAI); | ||
expect(generativeModel).toBeInstanceOf(GenerativeModel); | ||
expect(generativeModelPreview).toBeInstanceOf(GenerativeModelPreview); | ||
}); | ||
it('given undefined google auth options, should be instantiated', () => { | ||
@@ -19,0 +63,0 @@ expect(vertexai).toBeInstanceOf(VertexAI); |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
Environment variable access
Supply chain riskPackage accesses environment variables, which may be a sign of credential stuffing or data theft.
Found 3 instances in 1 package
License Policy Violation
LicenseThis package is not allowed per your license policy. Review the package's license to ensure compliance.
Found 1 instance in 1 package
Environment variable access
Supply chain riskPackage accesses environment variables, which may be a sign of credential stuffing or data theft.
Found 1 instance in 1 package
1126735
156
20692
428
6