Big News: Socket raises $60M Series C at a $1B valuation to secure software supply chains for AI-driven development.Announcement
Sign In

ai-sql

Package Overview
Dependencies
Maintainers
1
Versions
7
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

ai-sql - npm Package Compare versions

Comparing version
0.0.1-4
to
0.0.1-6
+280
dist/index.cjs
'use strict';
var pg = require('pg');
var ai = require('ai');
var zod = require('zod');
var mysql$1 = require('mysql2');
var sqlite3 = require('sqlite3');
function _interopNamespaceDefault(e) {
var n = Object.create(null);
if (e) {
Object.keys(e).forEach(function (k) {
if (k !== 'default') {
var d = Object.getOwnPropertyDescriptor(e, k);
Object.defineProperty(n, k, d.get ? d : {
enumerable: true,
get: function () { return e[k]; }
});
}
});
}
n.default = e;
return Object.freeze(n);
}
var mysql__namespace = /*#__PURE__*/_interopNamespaceDefault(mysql$1);
class PostgresTool {
url;
client;
constructor(dbUrl) {
this.url = dbUrl;
}
async initialize() {
this.client = new pg.Client(this.url);
await this.client.connect();
}
async describe() {
const result = await this.client.query(`
SELECT table_schema,
table_name,
string_agg(
format('%s %s %s', column_name,
CASE
WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')'
WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')'
WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')'
ELSE data_type
END,
CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END),
',
' ORDER BY ordinal_position
) AS columns
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
GROUP BY table_schema, table_name
ORDER BY table_schema, table_name;
`);
const createTableStatements = result.rows.map((row) => {
const { table_schema, table_name, columns } = row;
return `
CREATE TABLE ${table_schema}.${table_name} (
${columns}
);
`.trim();
});
const schema = createTableStatements.join("\n\n");
return {
database: "PostgreSQL",
description: schema
};
}
async query(query) {
const result = await this.client.query(query);
const res = Array.isArray(result) ? result : [result];
return res.map((r) => r.rows);
}
}
const descriptionTemplate = (schema) => `Query a database with the following schema:
${JSON.stringify(schema)}`;
async function sqlTool(db, { notes } = {}) {
await db.initialize();
const schema = await db.describe();
schema.notes = notes;
return ai.tool({
description: descriptionTemplate(schema),
execute: async ({ query }) => {
return await db.query(query);
},
parameters: zod.z.object({
query: zod.z.string().describe(
`${schema.database} Query to execute.
Notes: ${schema.notes?.join(
", "
)}`
)
})
});
}
class MySQLTool {
url;
client;
constructor(dbUrl) {
this.url = dbUrl;
}
async initialize() {
this.client = mysql__namespace.createConnection(this.url);
this.client.connect();
}
async describe() {
const res = await new Promise(
(resolve, reject) => this.client.query(
{
sql: `SELECT table_schema,
table_name,
GROUP_CONCAT(
CONCAT(
column_name, ' ',
CASE
WHEN data_type = 'varchar' THEN CONCAT('VARCHAR(', character_maximum_length, ')')
WHEN data_type = 'char' THEN CONCAT('CHAR(', character_maximum_length, ')')
WHEN data_type = 'decimal' THEN CONCAT('DECIMAL(', numeric_precision, ',', numeric_scale, ')')
ELSE data_type
END, ' ',
IF(is_nullable = 'YES', 'NULL', 'NOT NULL')
) ORDER BY ordinal_position
SEPARATOR '
' -- Newline separator
) AS columns
FROM information_schema.columns
WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
GROUP BY table_schema, table_name
ORDER BY table_schema, table_name;
`
},
(err, rows, _fields) => {
if (err) {
reject(err);
}
resolve(rows);
}
)
);
const createTableStatements = res.map((row) => {
const { table_schema, table_name, columns } = row;
return `
CREATE TABLE ${table_schema}.${table_name} (
${columns}
);
`.trim();
});
const schema = createTableStatements.join("\n\n");
return {
database: "MySQL",
description: schema
};
}
async query(query) {
const res = await new Promise(
(resolve, reject) => this.client.query({ sql: query }, (err, res2) => {
if (err) {
reject(err);
}
resolve(res2);
})
);
return res;
}
}
function getDatabaseSchema(db) {
return new Promise((resolve, reject) => {
const getTablesQuery = `
SELECT name
FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
ORDER BY name;
`;
db.all(getTablesQuery, [], (err, tables) => {
if (err) {
return reject(`Failed to retrieve tables: ${err.message}`);
}
if (!tables.length) {
return resolve("No tables in this database.");
}
let schemaOutput = "";
let pendingTables = tables.length;
tables.forEach((name) => {
db.all(`PRAGMA table_info(${name})`, [], (err2, columns) => {
if (err2) {
return reject(
`Failed to get table_info for ${name}: ${err2.message}`
);
}
schemaOutput += `Table: ${name}
`;
columns.forEach((col) => {
let columnStr = ` ${col.name} ${col.type}`;
if (col.pk === 1) {
columnStr += " PRIMARY KEY";
}
if (col.notnull === 1) {
columnStr += " NOT NULL";
}
if (col.dflt_value !== null) {
columnStr += ` DEFAULT ${col.dflt_value}`;
}
schemaOutput += columnStr + "\n";
});
schemaOutput += "\n";
pendingTables -= 1;
if (pendingTables === 0) {
resolve(schemaOutput.trim());
}
});
});
});
});
}
class SQLiteTool {
path;
db;
constructor(dbPath) {
this.path = dbPath;
}
static inMemory() {
return new SQLiteTool(":memory:");
}
async initialize() {
this.db = new sqlite3.Database(this.path);
}
async describe() {
const description = await getDatabaseSchema(this.db);
return {
database: "PostgreSQL",
description: `
sqlite3 database at ${this.path}
schema:
${description}
`.trim()
};
}
async query(query) {
return await new Promise((resolve, reject) => {
this.db.run(query, (res, err) => {
if (err !== null) {
reject(err);
} else {
res.all((err2, rows) => {
if (err2 !== null) {
reject(err2);
} else {
resolve(rows);
}
});
}
});
});
}
}
async function postgres(dbUrl, options) {
return await sqlTool(new PostgresTool(dbUrl), options);
}
async function mysql(dbUrl, options) {
return await sqlTool(new MySQLTool(dbUrl), options);
}
async function sqlite(dbPath, options) {
return await sqlTool(new SQLiteTool(dbPath), options);
}
exports.mysql = mysql;
exports.postgres = postgres;
exports.sqlite = sqlite;
import * as ai from 'ai';
import * as zod from 'zod';
interface SqlToolOptions {
notes?: string[];
}
declare function postgres(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{
query: zod.ZodString;
}, "strip", zod.ZodTypeAny, {
query?: string;
}, {
query?: string;
}>, unknown[]> & {
execute: (args: {
query?: string;
}, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>;
}>;
declare function mysql(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{
query: zod.ZodString;
}, "strip", zod.ZodTypeAny, {
query?: string;
}, {
query?: string;
}>, unknown[]> & {
execute: (args: {
query?: string;
}, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>;
}>;
declare function sqlite(dbPath: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{
query: zod.ZodString;
}, "strip", zod.ZodTypeAny, {
query?: string;
}, {
query?: string;
}>, unknown[]> & {
execute: (args: {
query?: string;
}, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>;
}>;
export { mysql, postgres, sqlite };
import * as ai from 'ai';
import * as zod from 'zod';
interface SqlToolOptions {
notes?: string[];
}
declare function postgres(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{
query: zod.ZodString;
}, "strip", zod.ZodTypeAny, {
query?: string;
}, {
query?: string;
}>, unknown[]> & {
execute: (args: {
query?: string;
}, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>;
}>;
declare function mysql(dbUrl: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{
query: zod.ZodString;
}, "strip", zod.ZodTypeAny, {
query?: string;
}, {
query?: string;
}>, unknown[]> & {
execute: (args: {
query?: string;
}, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>;
}>;
declare function sqlite(dbPath: string, options?: SqlToolOptions): Promise<ai.CoreTool<zod.ZodObject<{
query: zod.ZodString;
}, "strip", zod.ZodTypeAny, {
query?: string;
}, {
query?: string;
}>, unknown[]> & {
execute: (args: {
query?: string;
}, options: ai.ToolExecutionOptions) => PromiseLike<unknown[]>;
}>;
export { mysql, postgres, sqlite };
import { Client } from 'pg';
import { tool } from 'ai';
import { z } from 'zod';
import * as mysql$1 from 'mysql2';
import { Database } from 'sqlite3';
class PostgresTool {
url;
client;
constructor(dbUrl) {
this.url = dbUrl;
}
async initialize() {
this.client = new Client(this.url);
await this.client.connect();
}
async describe() {
const result = await this.client.query(`
SELECT table_schema,
table_name,
string_agg(
format('%s %s %s', column_name,
CASE
WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')'
WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')'
WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')'
ELSE data_type
END,
CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END),
',
' ORDER BY ordinal_position
) AS columns
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
GROUP BY table_schema, table_name
ORDER BY table_schema, table_name;
`);
const createTableStatements = result.rows.map((row) => {
const { table_schema, table_name, columns } = row;
return `
CREATE TABLE ${table_schema}.${table_name} (
${columns}
);
`.trim();
});
const schema = createTableStatements.join("\n\n");
return {
database: "PostgreSQL",
description: schema
};
}
async query(query) {
const result = await this.client.query(query);
const res = Array.isArray(result) ? result : [result];
return res.map((r) => r.rows);
}
}
const descriptionTemplate = (schema) => `Query a database with the following schema:
${JSON.stringify(schema)}`;
async function sqlTool(db, { notes } = {}) {
await db.initialize();
const schema = await db.describe();
schema.notes = notes;
return tool({
description: descriptionTemplate(schema),
execute: async ({ query }) => {
return await db.query(query);
},
parameters: z.object({
query: z.string().describe(
`${schema.database} Query to execute.
Notes: ${schema.notes?.join(
", "
)}`
)
})
});
}
class MySQLTool {
url;
client;
constructor(dbUrl) {
this.url = dbUrl;
}
async initialize() {
this.client = mysql$1.createConnection(this.url);
this.client.connect();
}
async describe() {
const res = await new Promise(
(resolve, reject) => this.client.query(
{
sql: `SELECT table_schema,
table_name,
GROUP_CONCAT(
CONCAT(
column_name, ' ',
CASE
WHEN data_type = 'varchar' THEN CONCAT('VARCHAR(', character_maximum_length, ')')
WHEN data_type = 'char' THEN CONCAT('CHAR(', character_maximum_length, ')')
WHEN data_type = 'decimal' THEN CONCAT('DECIMAL(', numeric_precision, ',', numeric_scale, ')')
ELSE data_type
END, ' ',
IF(is_nullable = 'YES', 'NULL', 'NOT NULL')
) ORDER BY ordinal_position
SEPARATOR '
' -- Newline separator
) AS columns
FROM information_schema.columns
WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
GROUP BY table_schema, table_name
ORDER BY table_schema, table_name;
`
},
(err, rows, _fields) => {
if (err) {
reject(err);
}
resolve(rows);
}
)
);
const createTableStatements = res.map((row) => {
const { table_schema, table_name, columns } = row;
return `
CREATE TABLE ${table_schema}.${table_name} (
${columns}
);
`.trim();
});
const schema = createTableStatements.join("\n\n");
return {
database: "MySQL",
description: schema
};
}
async query(query) {
const res = await new Promise(
(resolve, reject) => this.client.query({ sql: query }, (err, res2) => {
if (err) {
reject(err);
}
resolve(res2);
})
);
return res;
}
}
function getDatabaseSchema(db) {
return new Promise((resolve, reject) => {
const getTablesQuery = `
SELECT name
FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
ORDER BY name;
`;
db.all(getTablesQuery, [], (err, tables) => {
if (err) {
return reject(`Failed to retrieve tables: ${err.message}`);
}
if (!tables.length) {
return resolve("No tables in this database.");
}
let schemaOutput = "";
let pendingTables = tables.length;
tables.forEach((name) => {
db.all(`PRAGMA table_info(${name})`, [], (err2, columns) => {
if (err2) {
return reject(
`Failed to get table_info for ${name}: ${err2.message}`
);
}
schemaOutput += `Table: ${name}
`;
columns.forEach((col) => {
let columnStr = ` ${col.name} ${col.type}`;
if (col.pk === 1) {
columnStr += " PRIMARY KEY";
}
if (col.notnull === 1) {
columnStr += " NOT NULL";
}
if (col.dflt_value !== null) {
columnStr += ` DEFAULT ${col.dflt_value}`;
}
schemaOutput += columnStr + "\n";
});
schemaOutput += "\n";
pendingTables -= 1;
if (pendingTables === 0) {
resolve(schemaOutput.trim());
}
});
});
});
});
}
class SQLiteTool {
path;
db;
constructor(dbPath) {
this.path = dbPath;
}
static inMemory() {
return new SQLiteTool(":memory:");
}
async initialize() {
this.db = new Database(this.path);
}
async describe() {
const description = await getDatabaseSchema(this.db);
return {
database: "PostgreSQL",
description: `
sqlite3 database at ${this.path}
schema:
${description}
`.trim()
};
}
async query(query) {
return await new Promise((resolve, reject) => {
this.db.run(query, (res, err) => {
if (err !== null) {
reject(err);
} else {
res.all((err2, rows) => {
if (err2 !== null) {
reject(err2);
} else {
resolve(rows);
}
});
}
});
});
}
}
async function postgres(dbUrl, options) {
return await sqlTool(new PostgresTool(dbUrl), options);
}
async function mysql(dbUrl, options) {
return await sqlTool(new MySQLTool(dbUrl), options);
}
async function sqlite(dbPath, options) {
return await sqlTool(new SQLiteTool(dbPath), options);
}
export { mysql, postgres, sqlite };
import * as mysql from "mysql2";
import { Schema, Database } from "../database";
export class MySQLTool implements Database {
private url: string;
private client: mysql.Connection;
constructor(dbUrl: string) {
this.url = dbUrl;
}
async initialize() {
this.client = mysql.createConnection(this.url);
this.client.connect();
}
async describe(): Promise<Schema> {
const res: {
table_schema: string;
table_name: string;
columns: string;
}[] = await new Promise((resolve, reject) =>
this.client.query(
{
sql: `SELECT table_schema,
table_name,
GROUP_CONCAT(
CONCAT(
column_name, ' ',
CASE
WHEN data_type = 'varchar' THEN CONCAT('VARCHAR(', character_maximum_length, ')')
WHEN data_type = 'char' THEN CONCAT('CHAR(', character_maximum_length, ')')
WHEN data_type = 'decimal' THEN CONCAT('DECIMAL(', numeric_precision, ',', numeric_scale, ')')
ELSE data_type
END, ' ',
IF(is_nullable = 'YES', 'NULL', 'NOT NULL')
) ORDER BY ordinal_position
SEPARATOR '\n' -- Newline separator
) AS columns
FROM information_schema.columns
WHERE table_schema NOT IN ('information_schema', 'mysql', 'performance_schema', 'sys')
GROUP BY table_schema, table_name
ORDER BY table_schema, table_name;
`,
},
(err, rows, _fields) => {
if (err) {
reject(err);
}
resolve(rows as any);
}
)
);
const createTableStatements = res.map((row) => {
const { table_schema, table_name, columns } = row;
// Construct the CREATE TABLE statement
return `
CREATE TABLE ${table_schema}.${table_name} (
${columns}
);
`.trim();
});
const schema = createTableStatements.join("\n\n");
return {
database: "MySQL",
description: schema,
};
}
async query(query: string): Promise<object[]> {
const res = await new Promise((resolve, reject) =>
this.client.query({ sql: query }, (err, res) => {
if (err) {
reject(err);
}
resolve(res);
})
);
return res as object[];
}
}
import { Client, QueryResult } from "pg";
import { Schema, Database } from "../database";
export class PostgresTool implements Database {
private url: string;
private client: Client;
constructor(dbUrl: string) {
this.url = dbUrl;
}
async initialize() {
this.client = new Client(this.url);
await this.client.connect();
}
async describe(): Promise<Schema> {
const result = await this.client.query(`
SELECT table_schema,
table_name,
string_agg(
format('%s %s %s', column_name,
CASE
WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')'
WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')'
WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')'
ELSE data_type
END,
CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END),
',\n ' ORDER BY ordinal_position
) AS columns
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
GROUP BY table_schema, table_name
ORDER BY table_schema, table_name;
`);
const createTableStatements = result.rows.map((row) => {
const { table_schema, table_name, columns } = row;
// Construct the CREATE TABLE statement
return `
CREATE TABLE ${table_schema}.${table_name} (
${columns}
);
`.trim();
});
const schema = createTableStatements.join("\n\n");
return {
database: "PostgreSQL",
description: schema,
};
}
async query(query: string) {
const result = await this.client.query(query);
const res = Array.isArray(result) ? result : [result];
return res.map((r: QueryResult) => r.rows);
}
}
import { RunResult, Database as SQLite } from "sqlite3";
import { Schema, Database } from "../database";
function getDatabaseSchema(db: SQLite): Promise<string> {
return new Promise((resolve, reject) => {
// Retrieve all "normal" tables (exclude SQLite's internal tables)
const getTablesQuery = `
SELECT name
FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
ORDER BY name;
`;
db.all(getTablesQuery, [], (err, tables) => {
if (err) {
return reject(`Failed to retrieve tables: ${err.message}`);
}
if (!tables.length) {
return resolve("No tables in this database.");
}
let schemaOutput = "";
let pendingTables = tables.length;
// Iterate over tables and get column info
tables.forEach((name) => {
db.all(`PRAGMA table_info(${name})`, [], (err, columns) => {
if (err) {
return reject(
`Failed to get table_info for ${name}: ${err.message}`,
);
}
// Add table header
schemaOutput += `Table: ${name}\n`;
// List each column with its type and other attributes
columns.forEach((col) => {
let columnStr = ` ${col.name} ${col.type}`;
if (col.pk === 1) {
columnStr += " PRIMARY KEY";
}
if (col.notnull === 1) {
columnStr += " NOT NULL";
}
if (col.dflt_value !== null) {
columnStr += ` DEFAULT ${col.dflt_value}`;
}
schemaOutput += columnStr + "\n";
});
// Add a blank line after each table
schemaOutput += "\n";
// If all tables have been processed, resolve the output
pendingTables -= 1;
if (pendingTables === 0) {
resolve(schemaOutput.trim());
}
});
});
});
});
}
export class SQLiteTool implements Database {
private path: string;
private db: SQLite;
constructor(dbPath: string) {
this.path = dbPath;
}
static inMemory() {
return new SQLiteTool(":memory:");
}
async initialize() {
this.db = new SQLite(this.path);
}
async describe(): Promise<Schema> {
const description: string = await getDatabaseSchema(this.db);
return {
database: "PostgreSQL",
description: `
sqlite3 database at ${this.path}
schema:
${description}
`.trim(),
};
}
async query(query: string): Promise<unknown[]> {
return await new Promise((resolve, reject) => {
this.db.run(query, (res: RunResult, err: Error | null) => {
if (err !== null) {
reject(err);
} else {
res.all((err, rows) => {
if (err !== null) {
reject(err);
} else {
resolve(rows as unknown[]);
}
});
}
});
});
}
}
+5
-2
{
"name": "ai-sql",
"description": "Give any Vercel AI SDK project the ability to interact with PostgreSQL, SQlite or MySQL databases in one line.",
"version": "0.0.1-4",
"version": "0.0.1-6",
"main": "./dist/index.cjs",
"module": "./dist/index.mjs",
"types": "./dist/index.d.cts",
"type": "module",
"exports": {

@@ -33,5 +34,7 @@ "require": {

"dependencies": {
"@types/pg": "^8.11.10",
"ai": "^4.0.23",
"mysql2": "^3.12.0",
"pg": "^8.13.1",
"@types/pg": "^8.11.10",
"sqlite3": "^5.1.7",
"zod": "^3.24.1"

@@ -38,0 +41,0 @@ },

@@ -7,14 +7,13 @@ # AI SQL

Note: We rely on bun for running these, but not for installing.
```ts
import * as ai from "ai";
import { createOpenAI } from "@ai-sdk/openai";
import { postgres, sqlTool } from "ai-sql"; // or mysql, sqlite
import { postgres } from "ai-sql"; // or mysql, sqlite
const openai = createOpenAI({
compatibility: "strict",
apiKey: process.env.OPENAI_API_KEY!,
});
const model = openai("gpt-4-turbo");
const { text } = await ai.generateText({

@@ -24,3 +23,3 @@ model: openai("gpt-4-turbo"),

tools: {
postgreSQL: await sqlTool(postgres(process.env.POSTGRES_URL!)),
database: await postgres(process.env.POSTGRES_URL!),
},

@@ -30,1 +29,40 @@ maxSteps: 3,

```
For more examples, see the [example](./example) directory.
## Creating a provider
```typescript
import { Schema, Database, sqlTool } from "ai-sql";
export class MyDbTool implements Database {
async initialize() {
// do setup here
}
async describe(): Promise<Schema> {
// describe the schema
return {
// database type
database: "my database",
// stringified schema representation
description: `
create table messages (
id integer primary key,
text string not null,
);
`,
};
}
async query(query: string) {
// return result rows here
return [];
}
}
const tools = {
database: await sqlTool(new MyDbTool()),
};
```

@@ -1,16 +0,8 @@

export interface ColumnDefinition {
name: string;
type: string;
primaryKey: boolean;
}
import { tool } from "ai";
import { z } from "zod";
export interface TableDefinition {
name: string;
columns: ColumnDefinition[];
}
export interface Schema {
database: string;
tables: TableDefinition[];
description: string;
notes?: string[];
}

@@ -23,3 +15,35 @@

query: (query: string) => Promise<object[]>;
query: (query: string) => Promise<unknown[]>;
}
const descriptionTemplate = (schema: unknown) =>
`Query a database with the following schema:\n\n${JSON.stringify(schema)}`;
export interface SqlToolOptions {
notes?: string[];
}
export async function sqlTool(db: Database, { notes }: SqlToolOptions = {}) {
await db.initialize();
const schema = await db.describe();
schema.notes = notes;
return tool({
description: descriptionTemplate(schema),
execute: async ({ query }) => {
return await db.query(query);
},
parameters: z.object({
query: z
.string()
.describe(
`${schema.database} Query to execute.\n\nNotes: ${schema.notes?.join(
", "
)}`
),
}),
});
}

@@ -1,56 +0,16 @@

import { tool } from "ai";
import { z } from "zod";
import { Client, QueryResult } from "pg";
import { PostgresTool } from "./providers/postgres";
import { sqlTool, SqlToolOptions } from "./database";
import { MySQLTool } from "./providers/mysql";
import { SQLiteTool } from "./providers/sqlite";
export const postgreSQLTool = async (database_url: string) => {
const client = new Client(database_url);
await client.connect();
export async function postgres(dbUrl: string, options?: SqlToolOptions) {
return await sqlTool(new PostgresTool(dbUrl), options);
}
const result = await client.query(`
SELECT table_schema,
table_name,
string_agg(
format('%s %s %s', column_name,
CASE
WHEN data_type = 'character varying' THEN 'VARCHAR(' || character_maximum_length || ')'
WHEN data_type = 'numeric' THEN 'NUMERIC(' || numeric_precision || ',' || numeric_scale || ')'
WHEN data_type = 'character' THEN 'CHAR(' || character_maximum_length || ')'
ELSE data_type
END,
CASE WHEN is_nullable = 'YES' THEN 'NULL' ELSE 'NOT NULL' END),
',\n ' ORDER BY ordinal_position
) AS columns
FROM information_schema.columns
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
GROUP BY table_schema, table_name
ORDER BY table_schema, table_name;
`);
export async function mysql(dbUrl: string, options?: SqlToolOptions) {
return await sqlTool(new MySQLTool(dbUrl), options);
}
const createTableStatements = result.rows.map((row) => {
const { table_schema, table_name, columns } = row;
// Construct the CREATE TABLE statement
return `
CREATE TABLE ${table_schema}.${table_name} (
${columns}
);
`.trim();
});
let schema = createTableStatements.join("\n\n");
return tool({
description: `Query PostgreSQL Database with the following schema:\n\n${schema}`,
execute: async (query) => {
const result = await client.query(query.query!);
const res = Array.isArray(result) ? result : [result];
const rows = res.map((r: QueryResult) => r.rows);
return rows;
},
parameters: z.object({
query: z.string().describe("SQL Query to execute"),
}),
});
};
export async function sqlite(dbPath: string, options?: SqlToolOptions) {
return await sqlTool(new SQLiteTool(dbPath), options);
}

Sorry, the diff of this file is not supported yet

{
"name": "example",
"version": "1.0.0",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"keywords": [],
"author": "",
"license": "ISC",
"description": "",
"dependencies": {
"@ai-sdk/openai": "^1.0.12",
"ai": "^4.0.23",
"ai-sql": "file:..",
"dotenv": "^16.4.7"
}
}
import { postgreSQLTool } from "../../src";
import * as dotenv from "dotenv";
import * as ai from "ai";
dotenv.config();
import { createOpenAI } from "@ai-sdk/openai";
const openai = createOpenAI({
compatibility: "strict",
apiKey: process.env.OPENAI_API_KEY!,
});
const params = {
model: openai("gpt-4-turbo"),
maxSteps: 3,
};
const { text: text1 } = await ai.generateText({
...params,
tools: {
postgreSQL: await postgreSQLTool(process.env.POSTGRES_URL!),
},
prompt:
"Create an employees table and an hours table to track work, add some example empoyees and hours (about 1 month worth of entries for each employee), and return all employees",
});
console.log("OUTPUT", text1);
const { text: text2 } = await ai.generateText({
...params,
tools: {
postgreSQL: await postgreSQLTool(process.env.POSTGRES_URL!),
},
prompt:
"How many employees worked more than 10 hours during any week of the month?",
});
console.log("OUTPUT", text2);