Comparing version 0.2.2 to 0.2.3
/// <reference types="node" /> | ||
import { Table } from 'apache-arrow'; | ||
import { type Schema, Table as ArrowTable } from 'apache-arrow'; | ||
import { type EmbeddingFunction } from './index'; | ||
export declare function convertToTable<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table>; | ||
export declare function convertToTable<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<ArrowTable>; | ||
export declare function fromRecordsToBuffer<T>(data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer>; | ||
export declare function fromTableToBuffer<T>(table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer>; | ||
export declare function createEmptyTable(schema: Schema): ArrowTable; |
@@ -25,4 +25,5 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.fromRecordsToBuffer = exports.convertToTable = void 0; | ||
exports.createEmptyTable = exports.fromTableToBuffer = exports.fromRecordsToBuffer = exports.convertToTable = void 0; | ||
const apache_arrow_1 = require("apache-arrow"); | ||
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it. | ||
function convertToTable(data, embeddings) { | ||
@@ -37,4 +38,4 @@ return __awaiter(this, void 0, void 0, function* () { | ||
if (columnsKey === 'vector') { | ||
const listBuilder = newVectorListBuilder(); | ||
const vectorSize = data[0].vector.length; | ||
const listBuilder = newVectorBuilder(vectorSize); | ||
for (const datum of data) { | ||
@@ -55,5 +56,3 @@ if (datum[columnsKey].length !== vectorSize) { | ||
const vectors = yield embeddings.embed(values); | ||
const listBuilder = newVectorListBuilder(); | ||
vectors.map(v => listBuilder.append(v)); | ||
records.vector = listBuilder.finish().toVector(); | ||
records.vector = (0, apache_arrow_1.vectorFromArray)(vectors, newVectorType(vectors[0].length)); | ||
} | ||
@@ -74,9 +73,13 @@ if (typeof values[0] === 'string') { | ||
// Creates a new Arrow ListBuilder that stores a Vector column | ||
function newVectorListBuilder() { | ||
const children = new apache_arrow_1.Field('item', new apache_arrow_1.Float32()); | ||
const list = new apache_arrow_1.List(children); | ||
function newVectorBuilder(dim) { | ||
return (0, apache_arrow_1.makeBuilder)({ | ||
type: list | ||
type: newVectorType(dim) | ||
}); | ||
} | ||
// Creates the Arrow Type for a Vector column with dimension `dim` | ||
function newVectorType(dim) { | ||
const children = new apache_arrow_1.Field('item', new apache_arrow_1.Float32()); | ||
return new apache_arrow_1.FixedSizeList(dim, children); | ||
} | ||
// Converts an Array of records into Arrow IPC format | ||
function fromRecordsToBuffer(data, embeddings) { | ||
@@ -90,1 +93,23 @@ return __awaiter(this, void 0, void 0, function* () { | ||
exports.fromRecordsToBuffer = fromRecordsToBuffer; | ||
// Converts an Arrow Table into Arrow IPC format | ||
function fromTableToBuffer(table, embeddings) { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
if (embeddings !== undefined) { | ||
const source = table.getChild(embeddings.sourceColumn); | ||
if (source === null) { | ||
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`); | ||
} | ||
const vectors = yield embeddings.embed(source.toArray()); | ||
const column = (0, apache_arrow_1.vectorFromArray)(vectors, newVectorType(vectors[0].length)); | ||
table = table.assign(new apache_arrow_1.Table({ vector: column })); | ||
} | ||
const writer = apache_arrow_1.RecordBatchFileWriter.writeAll(table); | ||
return Buffer.from(yield writer.toUint8Array()); | ||
}); | ||
} | ||
exports.fromTableToBuffer = fromTableToBuffer; | ||
// Creates an empty Arrow Table | ||
function createEmptyTable(schema) { | ||
return new apache_arrow_1.Table(schema); | ||
} | ||
exports.createEmptyTable = createEmptyTable; |
@@ -1,2 +0,2 @@ | ||
import { type Table as ArrowTable } from 'apache-arrow'; | ||
import { type Schema, Table as ArrowTable } from 'apache-arrow'; | ||
import type { EmbeddingFunction } from './embedding/embedding_function'; | ||
@@ -15,2 +15,3 @@ import { Query } from './query'; | ||
awsCredentials?: AwsCredentials; | ||
awsRegion?: string; | ||
apiKey?: string; | ||
@@ -20,2 +21,9 @@ region?: string; | ||
} | ||
export interface CreateTableOptions<T> { | ||
name: string; | ||
data?: Array<Record<string, unknown>> | ArrowTable | undefined; | ||
schema?: Schema | undefined; | ||
embeddingFunction?: EmbeddingFunction<T> | undefined; | ||
writeOptions?: WriteOptions | undefined; | ||
} | ||
/** | ||
@@ -43,2 +51,12 @@ * Connect to a LanceDB instance at the given URI | ||
/** | ||
* Creates a new Table, optionally initializing it with new data. | ||
* | ||
* @param {string} name - The name of the table. | ||
* @param data - Array of Records to be inserted into the table | ||
* @param schema - An Arrow Schema that describe this table columns | ||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table | ||
* @param {WriteOptions} writeOptions - The write options to use when creating the table. | ||
*/ | ||
createTable<T>({ name, data, schema, embeddingFunction, writeOptions }: CreateTableOptions<T>): Promise<Table<T>>; | ||
/** | ||
* Creates a new Table and initialize it with new data. | ||
@@ -75,3 +93,2 @@ * | ||
createTable<T>(name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>, options: WriteOptions): Promise<Table<T>>; | ||
createTableArrow(name: string, table: ArrowTable): Promise<Table>; | ||
/** | ||
@@ -177,4 +194,4 @@ * Drop an existing table. | ||
openTable<T>(name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>; | ||
createTable<T>(name: string, data: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>>; | ||
createTableArrow(name: string, table: ArrowTable): Promise<Table>; | ||
createTable<T>(name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>>; | ||
private createTableImpl; | ||
/** | ||
@@ -181,0 +198,0 @@ * Drop an existing table. |
@@ -36,2 +36,18 @@ "use strict"; | ||
Object.defineProperty(exports, "OpenAIEmbeddingFunction", { enumerable: true, get: function () { return openai_1.OpenAIEmbeddingFunction; } }); | ||
function getAwsArgs(opts) { | ||
const callArgs = []; | ||
const awsCredentials = opts.awsCredentials; | ||
if (awsCredentials !== undefined) { | ||
callArgs.push(awsCredentials.accessKeyId); | ||
callArgs.push(awsCredentials.secretKey); | ||
callArgs.push(awsCredentials.sessionToken); | ||
} | ||
else { | ||
callArgs.push(undefined); | ||
callArgs.push(undefined); | ||
callArgs.push(undefined); | ||
} | ||
callArgs.push(opts.awsRegion); | ||
return callArgs; | ||
} | ||
function connect(arg) { | ||
@@ -66,7 +82,7 @@ return __awaiter(this, void 0, void 0, function* () { | ||
constructor(db, options) { | ||
this._options = options; | ||
this._options = () => options; | ||
this._db = db; | ||
} | ||
get uri() { | ||
return this._options.uri; | ||
return this._options().uri; | ||
} | ||
@@ -83,17 +99,8 @@ /** | ||
return __awaiter(this, void 0, void 0, function* () { | ||
// TODO: move this thing into rust | ||
const callArgs = [this._db, name]; | ||
if (this._options.awsCredentials !== undefined) { | ||
callArgs.push(this._options.awsCredentials.accessKeyId); | ||
callArgs.push(this._options.awsCredentials.secretKey); | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
callArgs.push(this._options.awsCredentials.sessionToken); | ||
} | ||
} | ||
const tbl = yield databaseOpenTable.call(...callArgs); | ||
const tbl = yield databaseOpenTable.call(this._db, name, ...getAwsArgs(this._options())); | ||
if (embeddings !== undefined) { | ||
return new LocalTable(tbl, name, this._options, embeddings); | ||
return new LocalTable(tbl, name, this._options(), embeddings); | ||
} | ||
else { | ||
return new LocalTable(tbl, name, this._options); | ||
return new LocalTable(tbl, name, this._options()); | ||
} | ||
@@ -103,39 +110,52 @@ }); | ||
createTable(name, data, optsOrEmbedding, opt) { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
if (typeof name === 'string') { | ||
let writeOptions = new DefaultWriteOptions(); | ||
if (opt !== undefined && isWriteOptions(opt)) { | ||
writeOptions = opt; | ||
} | ||
else if (optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding)) { | ||
writeOptions = optsOrEmbedding; | ||
} | ||
let embeddings; | ||
if (optsOrEmbedding !== undefined && (0, embedding_function_1.isEmbeddingFunction)(optsOrEmbedding)) { | ||
embeddings = optsOrEmbedding; | ||
} | ||
return yield this.createTableImpl({ name, data, embeddingFunction: embeddings, writeOptions }); | ||
} | ||
return yield this.createTableImpl(name); | ||
}); | ||
} | ||
createTableImpl({ name, data, schema, embeddingFunction, writeOptions = new DefaultWriteOptions() }) { | ||
var _a; | ||
return __awaiter(this, void 0, void 0, function* () { | ||
let writeOptions = new DefaultWriteOptions(); | ||
if (opt !== undefined && isWriteOptions(opt)) { | ||
writeOptions = opt; | ||
let buffer; | ||
function isEmpty(data) { | ||
if (data instanceof apache_arrow_1.Table) { | ||
return data.data.length === 0; | ||
} | ||
return data.length === 0; | ||
} | ||
else if (optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding)) { | ||
writeOptions = optsOrEmbedding; | ||
if ((data === undefined) || isEmpty(data)) { | ||
if (schema === undefined) { | ||
throw new Error('Either data or schema needs to defined'); | ||
} | ||
buffer = yield (0, arrow_1.fromTableToBuffer)((0, arrow_1.createEmptyTable)(schema)); | ||
} | ||
let embeddings; | ||
if (optsOrEmbedding !== undefined && (0, embedding_function_1.isEmbeddingFunction)(optsOrEmbedding)) { | ||
embeddings = optsOrEmbedding; | ||
else if (data instanceof apache_arrow_1.Table) { | ||
buffer = yield (0, arrow_1.fromTableToBuffer)(data, embeddingFunction); | ||
} | ||
const createArgs = [this._db, name, yield (0, arrow_1.fromRecordsToBuffer)(data, embeddings), (_a = writeOptions.writeMode) === null || _a === void 0 ? void 0 : _a.toString()]; | ||
if (this._options.awsCredentials !== undefined) { | ||
createArgs.push(this._options.awsCredentials.accessKeyId); | ||
createArgs.push(this._options.awsCredentials.secretKey); | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
createArgs.push(this._options.awsCredentials.sessionToken); | ||
} | ||
else { | ||
// data is Array<Record<...>> | ||
buffer = yield (0, arrow_1.fromRecordsToBuffer)(data, embeddingFunction); | ||
} | ||
const tbl = yield tableCreate.call(...createArgs); | ||
if (embeddings !== undefined) { | ||
return new LocalTable(tbl, name, this._options, embeddings); | ||
const tbl = yield tableCreate.call(this._db, name, buffer, (_a = writeOptions === null || writeOptions === void 0 ? void 0 : writeOptions.writeMode) === null || _a === void 0 ? void 0 : _a.toString(), ...getAwsArgs(this._options())); | ||
if (embeddingFunction !== undefined) { | ||
return new LocalTable(tbl, name, this._options(), embeddingFunction); | ||
} | ||
else { | ||
return new LocalTable(tbl, name, this._options); | ||
return new LocalTable(tbl, name, this._options()); | ||
} | ||
}); | ||
} | ||
createTableArrow(name, table) { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
const writer = apache_arrow_1.RecordBatchFileWriter.writeAll(table); | ||
yield tableCreate.call(this._db, name, Buffer.from(yield writer.toUint8Array())); | ||
return yield this.openTable(name); | ||
}); | ||
} | ||
/** | ||
@@ -157,3 +177,3 @@ * Drop an existing table. | ||
this._embeddings = embeddings; | ||
this._options = options; | ||
this._options = () => options; | ||
} | ||
@@ -178,11 +198,3 @@ get name() { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
const callArgs = [this._tbl, yield (0, arrow_1.fromRecordsToBuffer)(data, this._embeddings), WriteMode.Append.toString()]; | ||
if (this._options.awsCredentials !== undefined) { | ||
callArgs.push(this._options.awsCredentials.accessKeyId); | ||
callArgs.push(this._options.awsCredentials.secretKey); | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
callArgs.push(this._options.awsCredentials.sessionToken); | ||
} | ||
} | ||
return tableAdd.call(...callArgs).then((newTable) => { this._tbl = newTable; }); | ||
return tableAdd.call(this._tbl, yield (0, arrow_1.fromRecordsToBuffer)(data, this._embeddings), WriteMode.Append.toString(), ...getAwsArgs(this._options())).then((newTable) => { this._tbl = newTable; }); | ||
}); | ||
@@ -198,11 +210,3 @@ } | ||
return __awaiter(this, void 0, void 0, function* () { | ||
const callArgs = [this._tbl, yield (0, arrow_1.fromRecordsToBuffer)(data, this._embeddings), WriteMode.Overwrite.toString()]; | ||
if (this._options.awsCredentials !== undefined) { | ||
callArgs.push(this._options.awsCredentials.accessKeyId); | ||
callArgs.push(this._options.awsCredentials.secretKey); | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
callArgs.push(this._options.awsCredentials.sessionToken); | ||
} | ||
} | ||
return tableAdd.call(...callArgs).then((newTable) => { this._tbl = newTable; }); | ||
return tableAdd.call(this._tbl, yield (0, arrow_1.fromRecordsToBuffer)(data, this._embeddings), WriteMode.Overwrite.toString(), ...getAwsArgs(this._options())).then((newTable) => { this._tbl = newTable; }); | ||
}); | ||
@@ -209,0 +213,0 @@ } |
@@ -1,4 +0,3 @@ | ||
import { type EmbeddingFunction, type Table, type VectorIndexParams, type Connection, type ConnectionOptions } from '../index'; | ||
import { type EmbeddingFunction, type Table, type VectorIndexParams, type Connection, type ConnectionOptions, type CreateTableOptions, type WriteOptions } from '../index'; | ||
import { Query } from '../query'; | ||
import { type Table as ArrowTable } from 'apache-arrow'; | ||
import { HttpLancedbClient } from './client'; | ||
@@ -16,5 +15,3 @@ /** | ||
openTable<T>(name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>; | ||
createTable(name: string, data: Array<Record<string, unknown>>): Promise<Table>; | ||
createTable<T>(name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>; | ||
createTableArrow(name: string, table: ArrowTable): Promise<Table>; | ||
createTable<T>(name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>>; | ||
dropTable(name: string): Promise<void>; | ||
@@ -21,0 +18,0 @@ } |
@@ -70,3 +70,3 @@ "use strict"; | ||
} | ||
createTable(name, data, embeddings) { | ||
createTable(name, data, optsOrEmbedding, opt) { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
@@ -76,7 +76,2 @@ throw new Error('Not implemented'); | ||
} | ||
createTableArrow(name, table) { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
throw new Error('Not implemented'); | ||
}); | ||
} | ||
dropTable(name) { | ||
@@ -83,0 +78,0 @@ return __awaiter(this, void 0, void 0, function* () { |
@@ -58,3 +58,5 @@ "use strict"; | ||
const table = yield createTestDB(opts, 2, 20); | ||
console.log(table); | ||
const con = yield lancedb.connect(opts); | ||
console.log(con); | ||
chai_1.assert.equal(con.uri, opts.uri); | ||
@@ -81,4 +83,4 @@ const results = yield table.search([0.1, 0.3]).limit(5).execute(); | ||
} | ||
return yield con.createTable('vectors', data); | ||
return yield con.createTable('vectors_2', data); | ||
}); | ||
} |
@@ -31,2 +31,3 @@ "use strict"; | ||
const index_1 = require("../index"); | ||
const apache_arrow_1 = require("apache-arrow"); | ||
const expect = chai.expect; | ||
@@ -136,2 +137,35 @@ const assert = chai.assert; | ||
(0, mocha_1.describe)('when creating a new dataset', function () { | ||
it('create an empty table', function () { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
const dir = yield (0, temp_1.track)().mkdir('lancejs'); | ||
const con = yield lancedb.connect(dir); | ||
const schema = new apache_arrow_1.Schema([new apache_arrow_1.Field('id', new apache_arrow_1.Int32()), new apache_arrow_1.Field('name', new apache_arrow_1.Utf8())]); | ||
const table = yield con.createTable({ name: 'vectors', schema }); | ||
assert.equal(table.name, 'vectors'); | ||
assert.deepEqual(yield con.tableNames(), ['vectors']); | ||
}); | ||
}); | ||
it('create a table with a empty data array', function () { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
const dir = yield (0, temp_1.track)().mkdir('lancejs'); | ||
const con = yield lancedb.connect(dir); | ||
const schema = new apache_arrow_1.Schema([new apache_arrow_1.Field('id', new apache_arrow_1.Int32()), new apache_arrow_1.Field('name', new apache_arrow_1.Utf8())]); | ||
const table = yield con.createTable({ name: 'vectors', schema, data: [] }); | ||
assert.equal(table.name, 'vectors'); | ||
assert.deepEqual(yield con.tableNames(), ['vectors']); | ||
}); | ||
}); | ||
it('create a table from an Arrow Table', function () { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
const dir = yield (0, temp_1.track)().mkdir('lancejs'); | ||
const con = yield lancedb.connect(dir); | ||
const i32s = new Int32Array(new Array(10)); | ||
const i32 = (0, apache_arrow_1.makeVector)(i32s); | ||
const data = new apache_arrow_1.Table({ vector: i32 }); | ||
const table = yield con.createTable({ name: 'vectors', data }); | ||
assert.equal(table.name, 'vectors'); | ||
assert.equal(yield table.countRows(), 10); | ||
assert.deepEqual(yield con.tableNames(), ['vectors']); | ||
}); | ||
}); | ||
it('creates a new table from javascript objects', function () { | ||
@@ -297,2 +331,15 @@ return __awaiter(this, void 0, void 0, function* () { | ||
}); | ||
it('should create embeddings for Arrow Table', function () { | ||
return __awaiter(this, void 0, void 0, function* () { | ||
const dir = yield (0, temp_1.track)().mkdir('lancejs'); | ||
const con = yield lancedb.connect(dir); | ||
const embeddingFunction = new TextEmbedding('name'); | ||
const names = (0, apache_arrow_1.vectorFromArray)(['foo', 'bar'], new apache_arrow_1.Utf8()); | ||
const data = new apache_arrow_1.Table({ name: names }); | ||
const table = yield con.createTable({ name: 'vectors', data, embeddingFunction }); | ||
assert.equal(table.name, 'vectors'); | ||
const results = yield table.search('foo').execute(); | ||
assert.equal(results.length, 2); | ||
}); | ||
}); | ||
}); | ||
@@ -299,0 +346,0 @@ }); |
{ | ||
"name": "vectordb", | ||
"version": "0.2.2", | ||
"version": "0.2.3", | ||
"description": " Serverless, low-latency vector database for AI applications", | ||
@@ -81,8 +81,8 @@ "main": "dist/index.js", | ||
"optionalDependencies": { | ||
"@lancedb/vectordb-darwin-arm64": "0.2.2", | ||
"@lancedb/vectordb-darwin-x64": "0.2.2", | ||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.2", | ||
"@lancedb/vectordb-linux-x64-gnu": "0.2.2", | ||
"@lancedb/vectordb-win32-x64-msvc": "0.2.2" | ||
"@lancedb/vectordb-darwin-arm64": "0.2.3", | ||
"@lancedb/vectordb-darwin-x64": "0.2.3", | ||
"@lancedb/vectordb-linux-arm64-gnu": "0.2.3", | ||
"@lancedb/vectordb-linux-x64-gnu": "0.2.3", | ||
"@lancedb/vectordb-win32-x64-msvc": "0.2.3" | ||
} | ||
} |
@@ -16,14 +16,15 @@ // Copyright 2023 Lance Developers. | ||
import { | ||
Field, | ||
Field, type FixedSizeListBuilder, | ||
Float32, | ||
List, type ListBuilder, | ||
makeBuilder, | ||
RecordBatchFileWriter, | ||
Table, Utf8, | ||
Utf8, | ||
type Vector, | ||
vectorFromArray | ||
FixedSizeList, | ||
vectorFromArray, type Schema, Table as ArrowTable | ||
} from 'apache-arrow' | ||
import { type EmbeddingFunction } from './index' | ||
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table> { | ||
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it. | ||
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<ArrowTable> { | ||
if (data.length === 0) { | ||
@@ -38,4 +39,4 @@ throw new Error('At least one record needs to be provided') | ||
if (columnsKey === 'vector') { | ||
const listBuilder = newVectorListBuilder() | ||
const vectorSize = (data[0].vector as any[]).length | ||
const listBuilder = newVectorBuilder(vectorSize) | ||
for (const datum of data) { | ||
@@ -57,5 +58,3 @@ if ((datum[columnsKey] as any[]).length !== vectorSize) { | ||
const vectors = await embeddings.embed(values as T[]) | ||
const listBuilder = newVectorListBuilder() | ||
vectors.map(v => listBuilder.append(v)) | ||
records.vector = listBuilder.finish().toVector() | ||
records.vector = vectorFromArray(vectors, newVectorType(vectors[0].length)) | ||
} | ||
@@ -72,14 +71,19 @@ | ||
return new Table(records) | ||
return new ArrowTable(records) | ||
} | ||
// Creates a new Arrow ListBuilder that stores a Vector column | ||
function newVectorListBuilder (): ListBuilder<Float32, any> { | ||
const children = new Field<Float32>('item', new Float32()) | ||
const list = new List(children) | ||
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> { | ||
return makeBuilder({ | ||
type: list | ||
type: newVectorType(dim) | ||
}) | ||
} | ||
// Creates the Arrow Type for a Vector column with dimension `dim` | ||
function newVectorType (dim: number): FixedSizeList<Float32> { | ||
const children = new Field<Float32>('item', new Float32()) | ||
return new FixedSizeList(dim, children) | ||
} | ||
// Converts an Array of records into Arrow IPC format | ||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> { | ||
@@ -90,1 +94,23 @@ const table = await convertToTable(data, embeddings) | ||
} | ||
// Converts an Arrow Table into Arrow IPC format | ||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> { | ||
if (embeddings !== undefined) { | ||
const source = table.getChild(embeddings.sourceColumn) | ||
if (source === null) { | ||
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`) | ||
} | ||
const vectors = await embeddings.embed(source.toArray() as T[]) | ||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length)) | ||
table = table.assign(new ArrowTable({ vector: column })) | ||
} | ||
const writer = RecordBatchFileWriter.writeAll(table) | ||
return Buffer.from(await writer.toUint8Array()) | ||
} | ||
// Creates an empty Arrow Table | ||
export function createEmptyTable (schema: Schema): ArrowTable { | ||
return new ArrowTable(schema) | ||
} |
180
src/index.ts
@@ -16,6 +16,6 @@ // Copyright 2023 Lance Developers. | ||
import { | ||
RecordBatchFileWriter, | ||
type Table as ArrowTable | ||
type Schema, | ||
Table as ArrowTable | ||
} from 'apache-arrow' | ||
import { fromRecordsToBuffer } from './arrow' | ||
import { createEmptyTable, fromRecordsToBuffer, fromTableToBuffer } from './arrow' | ||
import type { EmbeddingFunction } from './embedding/embedding_function' | ||
@@ -46,2 +46,4 @@ import { RemoteConnection } from './remote' | ||
awsRegion?: string | ||
// API key for the remote connections | ||
@@ -56,2 +58,36 @@ apiKey?: string | ||
function getAwsArgs (opts: ConnectionOptions): any[] { | ||
const callArgs = [] | ||
const awsCredentials = opts.awsCredentials | ||
if (awsCredentials !== undefined) { | ||
callArgs.push(awsCredentials.accessKeyId) | ||
callArgs.push(awsCredentials.secretKey) | ||
callArgs.push(awsCredentials.sessionToken) | ||
} else { | ||
callArgs.push(undefined) | ||
callArgs.push(undefined) | ||
callArgs.push(undefined) | ||
} | ||
callArgs.push(opts.awsRegion) | ||
return callArgs | ||
} | ||
export interface CreateTableOptions<T> { | ||
// Name of Table | ||
name: string | ||
// Data to insert into the Table | ||
data?: Array<Record<string, unknown>> | ArrowTable | undefined | ||
// Optional Arrow Schema for this table | ||
schema?: Schema | undefined | ||
// Optional embedding function used to create embeddings | ||
embeddingFunction?: EmbeddingFunction<T> | undefined | ||
// WriteOptions for this operation | ||
writeOptions?: WriteOptions | undefined | ||
} | ||
/** | ||
@@ -104,2 +140,13 @@ * Connect to a LanceDB instance at the given URI | ||
/** | ||
* Creates a new Table, optionally initializing it with new data. | ||
* | ||
* @param {string} name - The name of the table. | ||
* @param data - Array of Records to be inserted into the table | ||
* @param schema - An Arrow Schema that describe this table columns | ||
* @param {EmbeddingFunction} embeddings - An embedding function to use on this table | ||
* @param {WriteOptions} writeOptions - The write options to use when creating the table. | ||
*/ | ||
createTable<T> ({ name, data, schema, embeddingFunction, writeOptions }: CreateTableOptions<T>): Promise<Table<T>> | ||
/** | ||
* Creates a new Table and initialize it with new data. | ||
@@ -139,4 +186,2 @@ * | ||
createTableArrow(name: string, table: ArrowTable): Promise<Table> | ||
/** | ||
@@ -229,7 +274,7 @@ * Drop an existing table. | ||
export class LocalConnection implements Connection { | ||
private readonly _options: ConnectionOptions | ||
private readonly _options: () => ConnectionOptions | ||
private readonly _db: any | ||
constructor (db: any, options: ConnectionOptions) { | ||
this._options = options | ||
this._options = () => options | ||
this._db = db | ||
@@ -239,3 +284,3 @@ } | ||
get uri (): string { | ||
return this._options.uri | ||
return this._options().uri | ||
} | ||
@@ -266,55 +311,64 @@ | ||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> { | ||
// TODO: move this thing into rust | ||
const callArgs = [this._db, name] | ||
if (this._options.awsCredentials !== undefined) { | ||
callArgs.push(this._options.awsCredentials.accessKeyId) | ||
callArgs.push(this._options.awsCredentials.secretKey) | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
callArgs.push(this._options.awsCredentials.sessionToken) | ||
} | ||
} | ||
const tbl = await databaseOpenTable.call(...callArgs) | ||
const tbl = await databaseOpenTable.call(this._db, name, ...getAwsArgs(this._options())) | ||
if (embeddings !== undefined) { | ||
return new LocalTable(tbl, name, this._options, embeddings) | ||
return new LocalTable(tbl, name, this._options(), embeddings) | ||
} else { | ||
return new LocalTable(tbl, name, this._options) | ||
return new LocalTable(tbl, name, this._options()) | ||
} | ||
} | ||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> { | ||
let writeOptions: WriteOptions = new DefaultWriteOptions() | ||
if (opt !== undefined && isWriteOptions(opt)) { | ||
writeOptions = opt | ||
} else if (optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding)) { | ||
writeOptions = optsOrEmbedding | ||
async createTable<T> (name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> { | ||
if (typeof name === 'string') { | ||
let writeOptions: WriteOptions = new DefaultWriteOptions() | ||
if (opt !== undefined && isWriteOptions(opt)) { | ||
writeOptions = opt | ||
} else if (optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding)) { | ||
writeOptions = optsOrEmbedding | ||
} | ||
let embeddings: undefined | EmbeddingFunction<T> | ||
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) { | ||
embeddings = optsOrEmbedding | ||
} | ||
return await this.createTableImpl({ name, data, embeddingFunction: embeddings, writeOptions }) | ||
} | ||
return await this.createTableImpl(name) | ||
} | ||
let embeddings: undefined | EmbeddingFunction<T> | ||
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) { | ||
embeddings = optsOrEmbedding | ||
private async createTableImpl<T> ({ name, data, schema, embeddingFunction, writeOptions = new DefaultWriteOptions() }: { | ||
name: string | ||
data?: Array<Record<string, unknown>> | ArrowTable | undefined | ||
schema?: Schema | undefined | ||
embeddingFunction?: EmbeddingFunction<T> | undefined | ||
writeOptions?: WriteOptions | undefined | ||
}): Promise<Table<T>> { | ||
let buffer: Buffer | ||
function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean { | ||
if (data instanceof ArrowTable) { | ||
return data.data.length === 0 | ||
} | ||
return data.length === 0 | ||
} | ||
const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), writeOptions.writeMode?.toString()] | ||
if (this._options.awsCredentials !== undefined) { | ||
createArgs.push(this._options.awsCredentials.accessKeyId) | ||
createArgs.push(this._options.awsCredentials.secretKey) | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
createArgs.push(this._options.awsCredentials.sessionToken) | ||
if ((data === undefined) || isEmpty(data)) { | ||
if (schema === undefined) { | ||
throw new Error('Either data or schema needs to defined') | ||
} | ||
buffer = await fromTableToBuffer(createEmptyTable(schema)) | ||
} else if (data instanceof ArrowTable) { | ||
buffer = await fromTableToBuffer(data, embeddingFunction) | ||
} else { | ||
// data is Array<Record<...>> | ||
buffer = await fromRecordsToBuffer(data, embeddingFunction) | ||
} | ||
const tbl = await tableCreate.call(...createArgs) | ||
if (embeddings !== undefined) { | ||
return new LocalTable(tbl, name, this._options, embeddings) | ||
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...getAwsArgs(this._options())) | ||
if (embeddingFunction !== undefined) { | ||
return new LocalTable(tbl, name, this._options(), embeddingFunction) | ||
} else { | ||
return new LocalTable(tbl, name, this._options) | ||
return new LocalTable(tbl, name, this._options()) | ||
} | ||
} | ||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> { | ||
const writer = RecordBatchFileWriter.writeAll(table) | ||
await tableCreate.call(this._db, name, Buffer.from(await writer.toUint8Array())) | ||
return await this.openTable(name) | ||
} | ||
/** | ||
@@ -333,3 +387,3 @@ * Drop an existing table. | ||
private readonly _embeddings?: EmbeddingFunction<T> | ||
private readonly _options: ConnectionOptions | ||
private readonly _options: () => ConnectionOptions | ||
@@ -348,3 +402,3 @@ constructor (tbl: any, name: string, options: ConnectionOptions) | ||
this._embeddings = embeddings | ||
this._options = options | ||
this._options = () => options | ||
} | ||
@@ -371,11 +425,8 @@ | ||
async add (data: Array<Record<string, unknown>>): Promise<number> { | ||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()] | ||
if (this._options.awsCredentials !== undefined) { | ||
callArgs.push(this._options.awsCredentials.accessKeyId) | ||
callArgs.push(this._options.awsCredentials.secretKey) | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
callArgs.push(this._options.awsCredentials.sessionToken) | ||
} | ||
} | ||
return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable }) | ||
return tableAdd.call( | ||
this._tbl, | ||
await fromRecordsToBuffer(data, this._embeddings), | ||
WriteMode.Append.toString(), | ||
...getAwsArgs(this._options()) | ||
).then((newTable: any) => { this._tbl = newTable }) | ||
} | ||
@@ -390,11 +441,8 @@ | ||
async overwrite (data: Array<Record<string, unknown>>): Promise<number> { | ||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()] | ||
if (this._options.awsCredentials !== undefined) { | ||
callArgs.push(this._options.awsCredentials.accessKeyId) | ||
callArgs.push(this._options.awsCredentials.secretKey) | ||
if (this._options.awsCredentials.sessionToken !== undefined) { | ||
callArgs.push(this._options.awsCredentials.sessionToken) | ||
} | ||
} | ||
return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable }) | ||
return tableAdd.call( | ||
this._tbl, | ||
await fromRecordsToBuffer(data, this._embeddings), | ||
WriteMode.Overwrite.toString(), | ||
...getAwsArgs(this._options()) | ||
).then((newTable: any) => { this._tbl = newTable }) | ||
} | ||
@@ -401,0 +449,0 @@ |
@@ -17,7 +17,7 @@ // Copyright 2023 LanceDB Developers. | ||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection, | ||
type ConnectionOptions | ||
type ConnectionOptions, type CreateTableOptions, type WriteOptions | ||
} from '../index' | ||
import { Query } from '../query' | ||
import { type Table as ArrowTable, Vector } from 'apache-arrow' | ||
import { Vector } from 'apache-arrow' | ||
import { HttpLancedbClient } from './client' | ||
@@ -70,12 +70,6 @@ | ||
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table> | ||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>> | ||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> { | ||
async createTable<T> (name: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> { | ||
throw new Error('Not implemented') | ||
} | ||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> { | ||
throw new Error('Not implemented') | ||
} | ||
async dropTable (name: string): Promise<void> { | ||
@@ -82,0 +76,0 @@ await this._client.post(`/v1/table/${name}/drop/`) |
@@ -50,3 +50,5 @@ // Copyright 2023 Lance Developers. | ||
const table = await createTestDB(opts, 2, 20) | ||
console.log(table) | ||
const con = await lancedb.connect(opts) | ||
console.log(con) | ||
assert.equal(con.uri, opts.uri) | ||
@@ -74,3 +76,3 @@ | ||
return await con.createTable('vectors', data) | ||
return await con.createTable('vectors_2', data) | ||
} |
@@ -22,2 +22,3 @@ // Copyright 2023 LanceDB Developers. | ||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index' | ||
import { Field, Int32, makeVector, Schema, Utf8, Table as ArrowTable, vectorFromArray } from 'apache-arrow' | ||
@@ -123,2 +124,41 @@ const expect = chai.expect | ||
describe('when creating a new dataset', function () { | ||
it('create an empty table', async function () { | ||
const dir = await track().mkdir('lancejs') | ||
const con = await lancedb.connect(dir) | ||
const schema = new Schema( | ||
[new Field('id', new Int32()), new Field('name', new Utf8())] | ||
) | ||
const table = await con.createTable({ name: 'vectors', schema }) | ||
assert.equal(table.name, 'vectors') | ||
assert.deepEqual(await con.tableNames(), ['vectors']) | ||
}) | ||
it('create a table with a empty data array', async function () { | ||
const dir = await track().mkdir('lancejs') | ||
const con = await lancedb.connect(dir) | ||
const schema = new Schema( | ||
[new Field('id', new Int32()), new Field('name', new Utf8())] | ||
) | ||
const table = await con.createTable({ name: 'vectors', schema, data: [] }) | ||
assert.equal(table.name, 'vectors') | ||
assert.deepEqual(await con.tableNames(), ['vectors']) | ||
}) | ||
it('create a table from an Arrow Table', async function () { | ||
const dir = await track().mkdir('lancejs') | ||
const con = await lancedb.connect(dir) | ||
const i32s = new Int32Array(new Array<number>(10)) | ||
const i32 = makeVector(i32s) | ||
const data = new ArrowTable({ vector: i32 }) | ||
const table = await con.createTable({ name: 'vectors', data }) | ||
assert.equal(table.name, 'vectors') | ||
assert.equal(await table.countRows(), 10) | ||
assert.deepEqual(await con.tableNames(), ['vectors']) | ||
}) | ||
it('creates a new table from javascript objects', async function () { | ||
@@ -296,2 +336,16 @@ const dir = await track().mkdir('lancejs') | ||
}) | ||
it('should create embeddings for Arrow Table', async function () { | ||
const dir = await track().mkdir('lancejs') | ||
const con = await lancedb.connect(dir) | ||
const embeddingFunction = new TextEmbedding('name') | ||
const names = vectorFromArray(['foo', 'bar'], new Utf8()) | ||
const data = new ArrowTable({ name: names }) | ||
const table = await con.createTable({ name: 'vectors', data, embeddingFunction }) | ||
assert.equal(table.name, 'vectors') | ||
const results = await table.search('foo').execute() | ||
assert.equal(results.length, 2) | ||
}) | ||
}) | ||
@@ -298,0 +352,0 @@ }) |
156721
3732