diff --git a/packages/typegpu/src/tgsl/wgslGenerator.ts b/packages/typegpu/src/tgsl/wgslGenerator.ts index f8e9fc256e..3ee8d020d9 100644 --- a/packages/typegpu/src/tgsl/wgslGenerator.ts +++ b/packages/typegpu/src/tgsl/wgslGenerator.ts @@ -19,6 +19,8 @@ import { $gpuCallable, $internal, $providing, isMarkedInternal } from '../shared import { safeStringify } from '../shared/stringify.ts'; import { pow } from '../std/numeric.ts'; import { add, div, mul, neg, sub } from '../std/operators.ts'; +import { eq, ne, lt, le, gt, ge } from '../std/boolean.ts'; + import { isGPUCallable, isKnownAtComptime, @@ -85,6 +87,14 @@ const parenthesizedOps = [ ]; const binaryLogicalOps = ['&&', '||', '==', '!=', '===', '!==', '<', '<=', '>', '>=']; +const binaryRelationalOpToStdMap: Record = { + '===': eq.toString(), + '!==': ne.toString(), + '<': lt.toString(), + '<=': le.toString(), + '>': gt.toString(), + '>=': ge.toString(), +}; const bitShiftOps: string[] = ['<<', '>>', '<<=', '>>=']; @@ -383,24 +393,26 @@ ${this.ctx.pre}}`; throw new Error('Please use the !== operator instead of !='); } - if (op === '===' && isKnownAtComptime(lhsExpr) && isKnownAtComptime(rhsExpr)) { - return snip(lhsExpr.value === rhsExpr.value, bool, 'constant', false); - } - - if (op === '!==' && isKnownAtComptime(lhsExpr) && isKnownAtComptime(rhsExpr)) { - return snip(lhsExpr.value !== rhsExpr.value, bool, 'constant', false); - } - - if ( - (op === '<' || op === '<=' || op === '>' || op === '>=') && - isKnownAtComptime(lhsExpr) && - isKnownAtComptime(rhsExpr) - ) { + const stdBinaryRelationalOp = binaryRelationalOpToStdMap[op]; + if (stdBinaryRelationalOp && isKnownAtComptime(lhsExpr) && isKnownAtComptime(rhsExpr)) { const left = lhsExpr.value; const right = rhsExpr.value; + + switch (op) { + case '===': + return snip(left === right, bool, 'constant', false); + case '!==': + return snip(left !== right, bool, 'constant', false); + } + if (typeof left !== 'number' || typeof right !== 'number') { + const bothVectors = wgsl.isVec(lhsExpr.dataType) && wgsl.isVec(rhsExpr.dataType); throw new WgslTypeError( - `Inequality comparison '${op}' requires numeric operands, got '${typeof left}' and '${typeof right}'`, + `Comparison '${op}' requires numeric operands.${ + bothVectors + ? ` For component-wise comparison, use 'std.${stdBinaryRelationalOp}''.` + : '' + }`, ); } @@ -469,6 +481,33 @@ ${this.ctx.pre}}`; } } + if (stdBinaryRelationalOp) { + const equalityCheck = ['===', '!=='].includes(op); + const correctOperandTypes = + (wgsl.isNumericSchema(convLhs.dataType) && wgsl.isNumericSchema(convRhs.dataType)) || + (equalityCheck && wgsl.isBool(convLhs.dataType) && wgsl.isBool(convRhs.dataType)); + + if (!correctOperandTypes) { + const bothVectors = wgsl.isVec(convLhs.dataType) && wgsl.isVec(convRhs.dataType); + throw new WgslTypeError( + `Comparison '${op}' requires numeric${equalityCheck ? ' or boolean' : ''} operands. Got '${String(convLhs.dataType)}' and '${String(convRhs.dataType)}'.${ + bothVectors + ? ` For component-wise comparison, use 'std.${stdBinaryRelationalOp}'.` + : '' + }`, + ); + } + } + + if ( + (op === '&&' || op === '||') && + !(wgsl.isBool(convLhs.dataType) && wgsl.isBool(convRhs.dataType)) + ) { + throw new WgslTypeError( + `Logical expression '${op}' requires boolean operands. Got '${String(convLhs.dataType)}' and '${String(convRhs.dataType)}'.`, + ); + } + return snip( parenthesizedOps.includes(op) ? `(${lhsStr} ${OP_MAP[op] ?? op} ${rhsStr})` @@ -859,6 +898,13 @@ ${this.ctx.pre}}`; if (isKnownAtComptime(test)) { return test.value ? this._expression(consequentNode) : this._expression(alternativeNode); } else { + const convertedTest = tryConvertSnippet(this.ctx, test, bool, false); + if (!convertedTest) { + throw new Error( + `Ternary operator '${stringifyNode(expression)}' is invalid. Cannot convert condition to bool.`, + ); + } + const consequent = this._expression(consequentNode); const alternative = this._expression(alternativeNode); const [con, alt] = @@ -871,11 +917,11 @@ ${this.ctx.pre}}`; } return snip( - stitch`select(${alt}, ${con}, ${test})`, + stitch`select(${alt}, ${con}, ${convertedTest})`, con.dataType, 'runtime', // this select has side-effects only if the condition has side-effects - test.possibleSideEffects, + convertedTest.possibleSideEffects, ); } } diff --git a/packages/typegpu/tests/tgsl/binaryLogicalOps.test.ts b/packages/typegpu/tests/tgsl/binaryLogicalOps.test.ts new file mode 100644 index 0000000000..2eb6126826 --- /dev/null +++ b/packages/typegpu/tests/tgsl/binaryLogicalOps.test.ts @@ -0,0 +1,547 @@ +import { expect, describe, beforeEach } from 'vitest'; +import { it } from 'typegpu-testing-utility'; +import tgpu, { d } from 'typegpu'; + +describe('binaryLogicalOps', () => { + const Boid = d.struct({ pos: d.vec3f }); + const BoidOnSteroids = d.struct({ pos: d.vec3f, strength: d.f32 }); + + describe('relational', () => { + describe('comptime', () => { + it('handles numeric', () => { + const x = 7 as number; + const y = 8 as number; + + const f = () => { + 'use gpu'; + let r = true; + r = x === y; + r = x !== y; + r = x < y; + r = x <= y; + r = x > y; + r = x >= y; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() { + var r = true; + r = false; + r = true; + r = true; + r = true; + r = false; + r = false; + }" + `); + }); + + it('equality comparison handles non numeric operands', () => { + const x = Boid(); + const y = BoidOnSteroids(); + + const eq = () => { + 'use gpu'; + const _r = x === y; + }; + const ne = () => { + 'use gpu'; + const _r = x !== y; + }; + + expect(tgpu.resolve([eq])).toMatchInlineSnapshot(` + "fn eq() { + const _r = false; + }" + `); + expect(tgpu.resolve([ne])).toMatchInlineSnapshot(` + "fn ne() { + const _r = true; + }" + `); + }); + + it('throws when both operands are not numeric', () => { + const x = Boid(); + const y = BoidOnSteroids(); + + const eq = () => { + 'use gpu'; + const _r = x === y; + }; + const ne = () => { + 'use gpu'; + const _r = x !== y; + }; + const lt = () => { + 'use gpu'; + const _r = x < y; + }; + const le = () => { + 'use gpu'; + const _r = x <= y; + }; + const gt = () => { + 'use gpu'; + const _r = x > y; + }; + const ge = () => { + 'use gpu'; + const _r = x >= y; + }; + + expect(() => tgpu.resolve([lt])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn*:lt + - fn*:lt(): Comparison '<' requires numeric operands.] + `); + expect(() => tgpu.resolve([le])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn*:le + - fn*:le(): Comparison '<=' requires numeric operands.] + `); + expect(() => tgpu.resolve([gt])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn*:gt + - fn*:gt(): Comparison '>' requires numeric operands.] + `); + expect(() => tgpu.resolve([ge])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn*:ge + - fn*:ge(): Comparison '>=' requires numeric operands.] + `); + }); + + it('when both operands are vectors suggests std function', () => { + const x = d.vec3f(); + const y = x; + + const f = () => { + 'use gpu'; + return x >= y; + }; + + expect(() => tgpu.resolve([f])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn*:f + - fn*:f(): Comparison '>=' requires numeric operands. For component-wise comparison, use 'std.ge''.] + `); + }); + }); + + describe('runtime', () => { + it('handles numeric', () => { + const x = 7 as number; + + const f = tgpu.fn([d.i32])((y) => { + 'use gpu'; + let r = true; + r = x === y; + r = x !== y; + r = x < y; + r = x <= y; + r = x > y; + r = x >= y; + }); + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f(y: i32) { + var r = true; + r = (7i == y); + r = (7i != y); + r = (7i < y); + r = (7i <= y); + r = (7i > y); + r = (7i >= y); + }" + `); + }); + + it('equality comparison handles boolean operands', () => { + const a = false; + const cAccessor = tgpu.accessor(d.bool, () => true); + const f = tgpu.fn([d.bool])((b) => { + 'use gpu'; + let r = true; + r = cAccessor.$ === b; + r = a !== b; + }); + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f(b: bool) { + var r = true; + r = (true == b); + r = (false != b); + }" + `); + }); + + it('throws when both operands are not numeric', () => { + const xAccessor = tgpu.accessor(Boid, () => Boid()); + + const eq = tgpu.fn([BoidOnSteroids])((y) => { + 'use gpu'; + const _r = xAccessor.$ === Boid(y); + }); + const ne = tgpu.fn([BoidOnSteroids])((y) => { + 'use gpu'; + const _r = xAccessor.$ !== Boid(y); + }); + const lt = tgpu.fn([BoidOnSteroids])((y) => { + 'use gpu'; + const _r = xAccessor.$ < Boid(y); + }); + const le = tgpu.fn([BoidOnSteroids])((y) => { + 'use gpu'; + const _r = xAccessor.$ <= Boid(y); + }); + const gt = tgpu.fn([BoidOnSteroids])((y) => { + 'use gpu'; + const _r = xAccessor.$ > Boid(y); + }); + const ge = tgpu.fn([BoidOnSteroids])((y) => { + 'use gpu'; + const _r = xAccessor.$ >= Boid(y); + }); + + expect(() => tgpu.resolve([eq])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn:eq: Comparison '===' requires numeric or boolean operands. Got 'struct:Boid' and 'struct:Boid'.] + `); + expect(() => tgpu.resolve([ne])).toThrowErrorMatchingInlineSnapshot( + ` + [Error: Resolution of the following tree failed: + - + - fn:ne: Comparison '!==' requires numeric or boolean operands. Got 'struct:Boid' and 'struct:Boid'.] + `, + ); + expect(() => tgpu.resolve([lt])).toThrowErrorMatchingInlineSnapshot( + ` + [Error: Resolution of the following tree failed: + - + - fn:lt: Comparison '<' requires numeric operands. Got 'struct:Boid' and 'struct:Boid'.] + `, + ); + expect(() => tgpu.resolve([le])).toThrowErrorMatchingInlineSnapshot( + ` + [Error: Resolution of the following tree failed: + - + - fn:le: Comparison '<=' requires numeric operands. Got 'struct:Boid' and 'struct:Boid'.] + `, + ); + expect(() => tgpu.resolve([gt])).toThrowErrorMatchingInlineSnapshot( + ` + [Error: Resolution of the following tree failed: + - + - fn:gt: Comparison '>' requires numeric operands. Got 'struct:Boid' and 'struct:Boid'.] + `, + ); + expect(() => tgpu.resolve([ge])).toThrowErrorMatchingInlineSnapshot( + ` + [Error: Resolution of the following tree failed: + - + - fn:ge: Comparison '>=' requires numeric operands. Got 'struct:Boid' and 'struct:Boid'.] + `, + ); + }); + + it('when both operands are vectors suggests std function', () => { + const x = d.vec3f(); + + const f = tgpu.fn([d.vec3f])((y) => { + 'use gpu'; + return x === y; + }); + + expect(() => tgpu.resolve([f])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn:f: Comparison '===' requires numeric or boolean operands. Got 'vec3f' and 'vec3f'. For component-wise comparison, use 'std.eq'.] + `); + }); + }); + }); + + describe('operator &&', () => { + it('handles boolean operands', () => { + const and = tgpu.fn( + [d.bool, d.bool], + d.bool, + )((x, y) => { + 'use gpu'; + return x && y; + }); + + expect(tgpu.resolve([and])).toMatchInlineSnapshot(` + "fn and(x: bool, y: bool) -> bool { + return (x && y); + }" + `); + }); + + it('throws when both operands are not boolean', () => { + const and = tgpu.fn( + [d.u32, Boid], + d.bool, + )((x, y) => { + 'use gpu'; + return !!(x && y); + }); + + expect(() => tgpu.resolve([and])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn:and: Logical expression '&&' requires boolean operands. Got 'u32' and 'struct:Boid'.] + `); + }); + }); + + describe('operator ||', () => { + it('handles boolean operands', () => { + const or = tgpu.fn( + [d.bool, d.bool], + d.bool, + )((x, y) => { + 'use gpu'; + return x || y; + }); + + expect(tgpu.resolve([or])).toMatchInlineSnapshot(` + "fn or(x: bool, y: bool) -> bool { + return (x || y); + }" + `); + }); + + it('throws when both operands are not boolean', () => { + const or = tgpu.fn( + [d.u32, Boid], + d.bool, + )((x, y) => { + 'use gpu'; + return !!(x || y); + }); + + expect(() => tgpu.resolve([or])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn:or: Logical expression '||' requires boolean operands. Got 'u32' and 'struct:Boid'.] + `); + }); + }); + + describe('short-circuit evaluation', () => { + const state = { counter: 0, result: true }; + const getTrackedBool = tgpu.comptime(() => { + state.counter++; + return state.result; + }); + beforeEach(() => { + state.counter = 0; + state.result = true; + }); + + it('handles ||', () => { + const f = () => { + 'use gpu'; + let res = -1; + // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test + if (true || getTrackedBool()) { + res = 1; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + var res = -1; + { + res = 1i; + } + return res; + }" + `); + expect(state.counter).toBe(0); + }); + + it('handles chained ||', () => { + state.result = false; + + const f = () => { + 'use gpu'; + let res = -1; + // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test + if (getTrackedBool() || true || getTrackedBool() || getTrackedBool() || getTrackedBool()) { + res = 1; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + var res = -1; + { + res = 1i; + } + return res; + }" + `); + expect(state.counter).toEqual(1); + }); + + it('skips false lhs', () => { + const f = tgpu.fn( + [d.bool], + d.i32, + )((b) => { + 'use gpu'; + let res = -1; + // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test + if (false || b) { + res = 1; + } + return res; + }); + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f(b: bool) -> i32 { + var res = -1; + if (b) { + res = 1i; + } + return res; + }" + `); + }); + + it('throws when rhs cannot be converted to boolean', () => { + const b = false; + const f = tgpu.fn( + [d.vec3f], + d.bool, + )((v) => { + 'use gpu'; + + return !!(b || v); + }); + + expect(() => tgpu.resolve([f])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn:f: Cannot convert value of type 'vec3f' to any of the target types: [bool]] + `); + }); + + it('handles &&', () => { + const f = () => { + 'use gpu'; + let res = -1; + // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test + if (false && getTrackedBool()) { + res = 1; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + let res = -1; + return res; + }" + `); + expect(state.counter).toBe(0); + }); + + it('handles chained &&', () => { + const f = () => { + 'use gpu'; + let res = -1; + // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test + if (getTrackedBool() && false && getTrackedBool() && getTrackedBool() && getTrackedBool()) { + res = 1; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + let res = -1; + return res; + }" + `); + expect(state.counter).toBe(1); + }); + + it('skips true lhs', () => { + const f = tgpu.fn( + [d.bool], + d.i32, + )((b) => { + 'use gpu'; + let res = -1; + // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test + if (true && b) { + res = 1; + } + return res; + }); + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f(b: bool) -> i32 { + var res = -1; + if (b) { + res = 1i; + } + return res; + }" + `); + }); + + it('throws when rhs cannot be converted to boolean', () => { + const b = true; + const f = tgpu.fn( + [d.vec3f], + d.bool, + )((v) => { + 'use gpu'; + + return !!(b && v); + }); + + expect(() => tgpu.resolve([f])).toThrowErrorMatchingInlineSnapshot(` + [Error: Resolution of the following tree failed: + - + - fn:f: Cannot convert value of type 'vec3f' to any of the target types: [bool]] + `); + }); + + it('handles mixed operators', () => { + const f = () => { + 'use gpu'; + let res = -1; + // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test + if (true || (getTrackedBool() && getTrackedBool())) { + res = 1; + } + return res; + }; + + expect(tgpu.resolve([f])).toMatchInlineSnapshot(` + "fn f() -> i32 { + var res = -1; + { + res = 1i; + } + return res; + }" + `); + expect(state.counter).toBe(0); + }); + }); +}); diff --git a/packages/typegpu/tests/tgsl/ternaryOperator.test.ts b/packages/typegpu/tests/tgsl/ternaryOperator.test.ts index 651fccb0dc..cf561534c4 100644 --- a/packages/typegpu/tests/tgsl/ternaryOperator.test.ts +++ b/packages/typegpu/tests/tgsl/ternaryOperator.test.ts @@ -181,18 +181,18 @@ describe('ternary operator', () => { `); }); - it('should throw when test is not comptime known', () => { + it('should throw when test cannot be converted to bool', () => { const myFn = tgpu.fn( - [d.u32], + [d.vec3f, d.u32], d.u32, - )((n) => { - return n > 0 ? n : -n; + )((v, n) => { + return v ? n : n + 1; }); expect(() => tgpu.resolve([myFn])).toThrowErrorMatchingInlineSnapshot(` [Error: Resolution of the following tree failed: - - - fn:myFn: Ternary operator '(n > 0) ? n : (-n)' is invalid. For more complex branching, please use 'std.select' or if/else statements.] + - fn:myFn: Cannot convert value of type 'vec3f' to any of the target types: [bool]] `); }); }); diff --git a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts index 56f705c668..cf90e382d8 100644 --- a/packages/typegpu/tests/tgsl/wgslGenerator.test.ts +++ b/packages/typegpu/tests/tgsl/wgslGenerator.test.ts @@ -1986,178 +1986,4 @@ describe('wgslGenerator', () => { - fn:fn3: Value 'NaN' (abstractFloat) cannot be resolved due to WGSL's Finite Math Assumption (see: https://www.w3.org/TR/WGSL/#finite-math-assumption). This value might be a result of a comptime-evaluated operation.] `); }); - - describe('short-circuit evaluation', () => { - const state = { - counter: 0, - result: true, - }; - - const getTrackedBool = tgpu.comptime(() => { - state.counter++; - return state.result; - }); - - beforeEach(() => { - state.counter = 0; - state.result = true; - }); - - it('handles `||`', () => { - const f = () => { - 'use gpu'; - let res = -1; - // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test - if (true || getTrackedBool()) { - res = 1; - } - return res; - }; - - expect(tgpu.resolve([f])).toMatchInlineSnapshot(` - "fn f() -> i32 { - var res = -1; - { - res = 1i; - } - return res; - }" - `); - expect(state.counter).toBe(0); - }); - - it('handles `&&`', () => { - const f = () => { - 'use gpu'; - let res = -1; - // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test - if (false && getTrackedBool()) { - res = 1; - } - return res; - }; - - expect(tgpu.resolve([f])).toMatchInlineSnapshot(` - "fn f() -> i32 { - let res = -1; - return res; - }" - `); - expect(state.counter).toBe(0); - }); - - it('handles chained `||`', () => { - state.result = false; - - const f = () => { - 'use gpu'; - let res = -1; - // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test - if (getTrackedBool() || true || getTrackedBool() || getTrackedBool() || getTrackedBool()) { - res = 1; - } - return res; - }; - - expect(tgpu.resolve([f])).toMatchInlineSnapshot(` - "fn f() -> i32 { - var res = -1; - { - res = 1i; - } - return res; - }" - `); - expect(state.counter).toEqual(1); - }); - - it('handles chained `&&`', () => { - const f = () => { - 'use gpu'; - let res = -1; - // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test - if (getTrackedBool() && false && getTrackedBool() && getTrackedBool() && getTrackedBool()) { - res = 1; - } - return res; - }; - - expect(tgpu.resolve([f])).toMatchInlineSnapshot(` - "fn f() -> i32 { - let res = -1; - return res; - }" - `); - expect(state.counter).toBe(1); - }); - - it('handles mixed logical operators', () => { - const f = () => { - 'use gpu'; - let res = -1; - // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test - if (true || (getTrackedBool() && getTrackedBool())) { - res = 1; - } - return res; - }; - - expect(tgpu.resolve([f])).toMatchInlineSnapshot(` - "fn f() -> i32 { - var res = -1; - { - res = 1i; - } - return res; - }" - `); - expect(state.counter).toBe(0); - }); - - it('skips lhs if known at compile time', () => { - const f1 = tgpu.fn( - [d.bool], - d.i32, - )((b) => { - 'use gpu'; - let res = -1; - // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test - if (false || b) { - res = 1; - } - return res; - }); - - const f2 = tgpu.fn( - [d.bool], - d.i32, - )((b) => { - 'use gpu'; - let res = -1; - // oxlint-disable-next-line(no-constant-binary-expression) -- part of the test - if (true && b) { - res = 1; - } - return res; - }); - - expect(tgpu.resolve([f1, f2])).toMatchInlineSnapshot(` - "fn f1(b: bool) -> i32 { - var res = -1; - if (b) { - res = 1i; - } - return res; - } - - fn f2(b: bool) -> i32 { - var res = -1; - if (b) { - res = 1i; - } - return res; - }" - `); - }); - }); });