@zenstackhq/runtime
Advanced tools
@@ -82,4 +82,10 @@ /** | ||
/** | ||
* Mapping from foreign key field names to relation field names | ||
* If the field is a foreign key field, the field name of the corresponding relation field. | ||
* Only available on foreign key fields. | ||
*/ | ||
relationField?: string; | ||
/** | ||
* Mapping from foreign key field names to relation field names. | ||
* Only available on relation fields. | ||
*/ | ||
foreignKeyMapping?: Record<string, string>; | ||
@@ -86,0 +92,0 @@ /** |
@@ -21,2 +21,3 @@ "use strict"; | ||
const proxy_1 = require("./proxy"); | ||
const utils_1 = require("./utils"); | ||
/** | ||
@@ -27,3 +28,3 @@ * Gets an enhanced Prisma client that supports `@default(auth())` attribute. | ||
*/ | ||
function withDefaultAuth(prisma, options, context) { | ||
function withDefaultAuth(prisma, options, context = {}) { | ||
return (0, proxy_1.makeProxy)(prisma, options.modelMeta, (_prisma, model) => new DefaultAuthHandler(_prisma, model, options, context), 'defaultAuth'); | ||
@@ -34,8 +35,4 @@ } | ||
constructor(prisma, model, options, context) { | ||
var _a; | ||
super(prisma, model, options); | ||
this.context = context; | ||
if (!((_a = this.context) === null || _a === void 0 ? void 0 : _a.user)) { | ||
throw new Error(`Using \`auth()\` in \`@default\` requires a user context`); | ||
} | ||
this.userContext = this.context.user; | ||
@@ -71,3 +68,3 @@ } | ||
// set field value extracted from `auth()` | ||
data[fieldInfo.name] = authDefaultValue; | ||
this.setAuthDefaultValue(fieldInfo, model, data, authDefaultValue); | ||
} | ||
@@ -91,4 +88,42 @@ } | ||
} | ||
setAuthDefaultValue(fieldInfo, model, data, authDefaultValue) { | ||
var _a; | ||
if (fieldInfo.isForeignKey && !(0, utils_1.isUnsafeMutate)(model, data, this.options.modelMeta)) { | ||
// if the field is a fk, and the create payload is not unsafe, we need to translate | ||
// the fk field setting to a `connect` of the corresponding relation field | ||
const relFieldName = fieldInfo.relationField; | ||
if (!relFieldName) { | ||
throw new Error(`Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found`); | ||
} | ||
const relationField = (0, cross_1.requireField)(this.options.modelMeta, model, relFieldName); | ||
// construct a `{ connect: { ... } }` payload | ||
let connect = (_a = data[relationField.name]) === null || _a === void 0 ? void 0 : _a.connect; | ||
if (!connect) { | ||
connect = {}; | ||
data[relationField.name] = { connect }; | ||
} | ||
// sets the opposite fk field to value `authDefaultValue` | ||
const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo); | ||
if (!oppositeFkFieldName) { | ||
throw new Error(`Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\``); | ||
} | ||
connect[oppositeFkFieldName] = authDefaultValue; | ||
} | ||
else { | ||
// set default value directly | ||
data[fieldInfo.name] = authDefaultValue; | ||
} | ||
} | ||
getOppositeFkFieldName(relationField, fieldInfo) { | ||
if (!relationField.foreignKeyMapping) { | ||
return undefined; | ||
} | ||
const entry = Object.entries(relationField.foreignKeyMapping).find(([, v]) => v === fieldInfo.name); | ||
return entry === null || entry === void 0 ? void 0 : entry[0]; | ||
} | ||
getDefaultValueFromAuth(fieldInfo) { | ||
var _a; | ||
if (!this.userContext) { | ||
throw new Error(`Evaluating default value of field \`${fieldInfo.name}\` requires a user context`); | ||
} | ||
return (_a = fieldInfo.defaultValueProvider) === null || _a === void 0 ? void 0 : _a.call(fieldInfo, this.userContext); | ||
@@ -95,0 +130,0 @@ } |
@@ -63,2 +63,6 @@ "use strict"; | ||
this.injectSelectIncludeHierarchy(model, args); | ||
if (args.orderBy) { | ||
// `orderBy` may contain fields from base types | ||
args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy); | ||
} | ||
if (this.options.logPrismaQuery) { | ||
@@ -109,3 +113,3 @@ this.logger.info(`[delegate] \`${method}\` ${this.getModelName(model)}: ${(0, utils_1.formatObject)(args)}`); | ||
} | ||
buildWhereHierarchy(where) { | ||
buildWhereHierarchy(model, where) { | ||
if (!where) { | ||
@@ -116,7 +120,7 @@ return undefined; | ||
Object.entries(where).forEach(([field, value]) => { | ||
const fieldInfo = (0, cross_1.resolveField)(this.options.modelMeta, this.model, field); | ||
const fieldInfo = (0, cross_1.resolveField)(this.options.modelMeta, model, field); | ||
if (!(fieldInfo === null || fieldInfo === void 0 ? void 0 : fieldInfo.inheritedFrom)) { | ||
return; | ||
} | ||
let base = this.getBaseModel(this.model); | ||
let base = this.getBaseModel(model); | ||
let target = where; | ||
@@ -153,3 +157,8 @@ while (base) { | ||
for (const [field, value] of Object.entries(args[kind])) { | ||
if (value !== undefined) { | ||
const fieldInfo = (0, cross_1.resolveField)(this.options.modelMeta, model, field); | ||
if (fieldInfo && value !== undefined) { | ||
if (value === null || value === void 0 ? void 0 : value.orderBy) { | ||
// `orderBy` may contain fields from base types | ||
value.orderBy = this.buildWhereHierarchy(fieldInfo.type, value.orderBy); | ||
} | ||
if (this.injectBaseFieldSelect(model, field, value, args, kind)) { | ||
@@ -159,3 +168,2 @@ delete args[kind][field]; | ||
else { | ||
const fieldInfo = (0, cross_1.resolveField)(this.options.modelMeta, model, field); | ||
if (fieldInfo && this.isDelegateOrDescendantOfDelegate(fieldInfo.type)) { | ||
@@ -685,9 +693,9 @@ let nextValue = value; | ||
if (args.cursor) { | ||
args.cursor = this.buildWhereHierarchy(args.cursor); | ||
args.cursor = this.buildWhereHierarchy(this.model, args.cursor); | ||
} | ||
if (args.orderBy) { | ||
args.orderBy = this.buildWhereHierarchy(args.orderBy); | ||
args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy); | ||
} | ||
if (args.where) { | ||
args.where = this.buildWhereHierarchy(args.where); | ||
args.where = this.buildWhereHierarchy(this.model, args.where); | ||
} | ||
@@ -707,6 +715,6 @@ if (this.options.logPrismaQuery) { | ||
if (args === null || args === void 0 ? void 0 : args.cursor) { | ||
args.cursor = this.buildWhereHierarchy(args.cursor); | ||
args.cursor = this.buildWhereHierarchy(this.model, args.cursor); | ||
} | ||
if (args === null || args === void 0 ? void 0 : args.where) { | ||
args.where = this.buildWhereHierarchy(args.where); | ||
args.where = this.buildWhereHierarchy(this.model, args.where); | ||
} | ||
@@ -737,3 +745,3 @@ if (this.options.logPrismaQuery) { | ||
if (args.where) { | ||
args.where = this.buildWhereHierarchy(args.where); | ||
args.where = this.buildWhereHierarchy(this.model, args.where); | ||
} | ||
@@ -740,0 +748,0 @@ if (this.options.logPrismaQuery) { |
@@ -42,4 +42,3 @@ import { type DbClientContract } from '../../types'; | ||
private doUpdate; | ||
private isUnsafeMutate; | ||
private isAutoIncrementIdField; | ||
private validateUpdateInputSchema; | ||
updateMany(args: any): Promise<{ | ||
@@ -46,0 +45,0 @@ count: number; |
@@ -402,3 +402,3 @@ "use strict"; | ||
const schema = this.policyUtils.getZodSchema(model, 'create'); | ||
if (schema) { | ||
if (schema && data) { | ||
const parseResult = schema.safeParse(data); | ||
@@ -424,14 +424,17 @@ if (!parseResult.success) { | ||
args = this.policyUtils.clone(args); | ||
// do static input validation and check if post-create checks are needed | ||
// go through create items, statically check input to determine if post-create | ||
// check is needed, and also validate zod schema | ||
let needPostCreateCheck = false; | ||
for (const item of (0, cross_1.enumerate)(args.data)) { | ||
const validationResult = this.validateCreateInputSchema(this.model, item); | ||
if (validationResult !== item) { | ||
this.policyUtils.replace(item, validationResult); | ||
} | ||
const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); | ||
if (inputCheck === false) { | ||
// unconditionally deny | ||
throw this.policyUtils.deniedByPolicy(this.model, 'create', undefined, constants_1.CrudFailureReason.ACCESS_POLICY_VIOLATION); | ||
} | ||
else if (inputCheck === true) { | ||
const r = this.validateCreateInputSchema(this.model, item); | ||
if (r !== item) { | ||
this.policyUtils.replace(item, r); | ||
} | ||
// unconditionally allow | ||
} | ||
@@ -441,3 +444,2 @@ else if (inputCheck === undefined) { | ||
needPostCreateCheck = true; | ||
break; | ||
} | ||
@@ -577,3 +579,3 @@ } | ||
// - unsafe: { data: { userId: 1 } } | ||
const unsafe = this.isUnsafeMutate(model, args); | ||
const unsafe = (0, utils_1.isUnsafeMutate)(model, args, this.modelMeta); | ||
// handles the connection to upstream entity | ||
@@ -675,2 +677,6 @@ const reversedQuery = this.policyUtils.buildReversedQuery(context, true, unsafe); | ||
const updatePayload = (_c = args.data) !== null && _c !== void 0 ? _c : args; | ||
const validatedPayload = this.validateUpdateInputSchema(model, updatePayload); | ||
if (validatedPayload !== updatePayload) { | ||
this.policyUtils.replace(updatePayload, validatedPayload); | ||
} | ||
if (updatePayload) { | ||
@@ -733,2 +739,3 @@ for (const key of Object.keys(updatePayload)) { | ||
} | ||
args.data = this.validateUpdateInputSchema(model, args.data); | ||
const updateGuard = this.policyUtils.getAuthGuard(db, model, 'update'); | ||
@@ -782,3 +789,6 @@ if (this.policyUtils.isTrue(updateGuard) || this.policyUtils.isFalse(updateGuard)) { | ||
// convert upsert to update | ||
context.parent.update = { where: args.where, data: args.update }; | ||
context.parent.update = { | ||
where: args.where, | ||
data: this.validateUpdateInputSchema(model, args.update), | ||
}; | ||
delete context.parent.upsert; | ||
@@ -868,17 +878,21 @@ // continue visiting the new payload | ||
} | ||
isUnsafeMutate(model, args) { | ||
if (!args) { | ||
return false; | ||
} | ||
for (const k of Object.keys(args)) { | ||
const field = (0, cross_1.resolveField)(this.modelMeta, model, k); | ||
if (field && (this.isAutoIncrementIdField(field) || field.isForeignKey)) { | ||
return true; | ||
// Validates the given update payload against Zod schema if any | ||
validateUpdateInputSchema(model, data) { | ||
const schema = this.policyUtils.getZodSchema(model, 'update'); | ||
if (schema && data) { | ||
// update payload can contain non-literal fields, like: | ||
// { x: { increment: 1 } } | ||
// we should only validate literal fields | ||
const literalData = Object.entries(data).reduce((acc, [k, v]) => (Object.assign(Object.assign({}, acc), (typeof v !== 'object' ? { [k]: v } : {}))), {}); | ||
const parseResult = schema.safeParse(literalData); | ||
if (!parseResult.success) { | ||
throw this.policyUtils.deniedByPolicy(model, 'update', `input failed validation: ${(0, zod_validation_error_1.fromZodError)(parseResult.error)}`, constants_1.CrudFailureReason.DATA_VALIDATION_VIOLATION, parseResult.error); | ||
} | ||
// schema may have transformed field values, use it to overwrite the original data | ||
return Object.assign(Object.assign({}, data), parseResult.data); | ||
} | ||
return false; | ||
else { | ||
return data; | ||
} | ||
} | ||
isAutoIncrementIdField(field) { | ||
return field.isId && field.isAutoIncrement; | ||
} | ||
updateMany(args) { | ||
@@ -895,2 +909,3 @@ return __awaiter(this, void 0, void 0, function* () { | ||
this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); | ||
args.data = this.validateUpdateInputSchema(this.model, args.data); | ||
if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { | ||
@@ -897,0 +912,0 @@ // use a transaction to do post-update checks |
@@ -60,2 +60,6 @@ import { ZodError } from 'zod'; | ||
/** | ||
* Checks if the given model has any field-level override policy guard for the given operation. | ||
*/ | ||
hasOverrideAuthGuard(model: string, operation: PolicyOperationKind): boolean; | ||
/** | ||
* Checks model creation policy based on static analysis to the input args. | ||
@@ -62,0 +66,0 @@ * |
@@ -290,2 +290,16 @@ "use strict"; | ||
/** | ||
* Checks if the given model has any field-level override policy guard for the given operation. | ||
*/ | ||
hasOverrideAuthGuard(model, operation) { | ||
const guard = this.requireGuard(model); | ||
switch (operation) { | ||
case 'read': | ||
return Object.keys(guard).some((k) => k.startsWith(constants_1.FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX)); | ||
case 'update': | ||
return Object.keys(guard).some((k) => k.startsWith(constants_1.FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX)); | ||
default: | ||
return false; | ||
} | ||
} | ||
/** | ||
* Checks model creation policy based on static analysis to the input args. | ||
@@ -542,3 +556,3 @@ * | ||
let guard = this.getAuthGuard(db, model, operation, preValue); | ||
if (this.isFalse(guard)) { | ||
if (this.isFalse(guard) && !this.hasOverrideAuthGuard(model, operation)) { | ||
throw this.deniedByPolicy(model, operation, `entity ${(0, utils_1.formatObject)(uniqueFilter)} failed policy check`, constants_1.CrudFailureReason.ACCESS_POLICY_VIOLATION); | ||
@@ -672,3 +686,3 @@ } | ||
const guard = this.getAuthGuard(db, model, operation); | ||
if (this.isFalse(guard)) { | ||
if (this.isFalse(guard) && !this.hasOverrideAuthGuard(model, operation)) { | ||
throw this.deniedByPolicy(model, operation, undefined, constants_1.CrudFailureReason.ACCESS_POLICY_VIOLATION); | ||
@@ -675,0 +689,0 @@ } |
@@ -0,1 +1,2 @@ | ||
import { FieldInfo, ModelMeta } from '..'; | ||
import type { DbClientContract } from '../types'; | ||
@@ -9,1 +10,3 @@ /** | ||
export declare function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error; | ||
export declare function isUnsafeMutate(model: string, args: any, modelMeta: ModelMeta): boolean; | ||
export declare function isAutoIncrementIdField(field: FieldInfo): boolean | undefined; |
@@ -26,4 +26,5 @@ "use strict"; | ||
Object.defineProperty(exports, "__esModule", { value: true }); | ||
exports.prismaClientUnknownRequestError = exports.prismaClientKnownRequestError = exports.prismaClientValidationError = exports.formatObject = void 0; | ||
exports.isAutoIncrementIdField = exports.isUnsafeMutate = exports.prismaClientUnknownRequestError = exports.prismaClientKnownRequestError = exports.prismaClientValidationError = exports.formatObject = void 0; | ||
const util = __importStar(require("util")); | ||
const __1 = require(".."); | ||
/** | ||
@@ -51,2 +52,20 @@ * Formats an object for pretty printing. | ||
exports.prismaClientUnknownRequestError = prismaClientUnknownRequestError; | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
function isUnsafeMutate(model, args, modelMeta) { | ||
if (!args) { | ||
return false; | ||
} | ||
for (const k of Object.keys(args)) { | ||
const field = (0, __1.resolveField)(modelMeta, model, k); | ||
if (field && (isAutoIncrementIdField(field) || field.isForeignKey)) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
exports.isUnsafeMutate = isUnsafeMutate; | ||
function isAutoIncrementIdField(field) { | ||
return field.isId && field.isAutoIncrement; | ||
} | ||
exports.isAutoIncrementIdField = isAutoIncrementIdField; | ||
//# sourceMappingURL=utils.js.map |
{ | ||
"name": "@zenstackhq/runtime", | ||
"displayName": "ZenStack Runtime Library", | ||
"version": "2.0.0-alpha.1", | ||
"version": "2.0.0-alpha.2", | ||
"description": "Runtime of ZenStack for both client-side and server-side environments.", | ||
@@ -42,2 +42,5 @@ "repository": { | ||
"default": "./cross/index.js" | ||
}, | ||
"./prisma": { | ||
"types": "./prisma.d.ts" | ||
} | ||
@@ -44,0 +47,0 @@ }, |
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is not supported yet
563165
1.88%92
1.1%7044
1.5%