Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 62 additions & 16 deletions packages/typegpu/src/tgsl/wgslGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import { $gpuCallable, $internal, $providing, isMarkedInternal } from '../shared
import { safeStringify } from '../shared/stringify.ts';
import { pow } from '../std/numeric.ts';
import { add, div, mul, neg, sub } from '../std/operators.ts';
import { eq, ne, lt, le, gt, ge } from '../std/boolean.ts';

import {
isGPUCallable,
isKnownAtComptime,
Expand Down Expand Up @@ -85,6 +87,14 @@ const parenthesizedOps = [
];

const binaryLogicalOps = ['&&', '||', '==', '!=', '===', '!==', '<', '<=', '>', '>='];
const binaryRelationalOpToStdMap: Record<string, string> = {
'===': eq.toString(),
'!==': ne.toString(),
'<': lt.toString(),
'<=': le.toString(),
'>': gt.toString(),
'>=': ge.toString(),
};

const bitShiftOps: string[] = ['<<', '>>', '<<=', '>>='];

Expand Down Expand Up @@ -383,24 +393,26 @@ ${this.ctx.pre}}`;
throw new Error('Please use the !== operator instead of !=');
}

if (op === '===' && isKnownAtComptime(lhsExpr) && isKnownAtComptime(rhsExpr)) {
return snip(lhsExpr.value === rhsExpr.value, bool, 'constant', false);
}

if (op === '!==' && isKnownAtComptime(lhsExpr) && isKnownAtComptime(rhsExpr)) {
return snip(lhsExpr.value !== rhsExpr.value, bool, 'constant', false);
}

if (
(op === '<' || op === '<=' || op === '>' || op === '>=') &&
isKnownAtComptime(lhsExpr) &&
isKnownAtComptime(rhsExpr)
) {
const stdBinaryRelationalOp = binaryRelationalOpToStdMap[op];
if (stdBinaryRelationalOp && isKnownAtComptime(lhsExpr) && isKnownAtComptime(rhsExpr)) {
const left = lhsExpr.value;
const right = rhsExpr.value;

switch (op) {
case '===':
return snip(left === right, bool, 'constant', false);
case '!==':
return snip(left !== right, bool, 'constant', false);
}

if (typeof left !== 'number' || typeof right !== 'number') {
const bothVectors = wgsl.isVec(lhsExpr.dataType) && wgsl.isVec(rhsExpr.dataType);
throw new WgslTypeError(
`Inequality comparison '${op}' requires numeric operands, got '${typeof left}' and '${typeof right}'`,
`Comparison '${op}' requires numeric operands.${
bothVectors
? ` For component-wise comparison, use 'std.${stdBinaryRelationalOp}''.`
: ''
}`,
);
}

Expand Down Expand Up @@ -469,6 +481,33 @@ ${this.ctx.pre}}`;
}
}

if (stdBinaryRelationalOp) {
const equalityCheck = ['===', '!=='].includes(op);
const correctOperandTypes =
(wgsl.isNumericSchema(convLhs.dataType) && wgsl.isNumericSchema(convRhs.dataType)) ||
(equalityCheck && wgsl.isBool(convLhs.dataType) && wgsl.isBool(convRhs.dataType));

if (!correctOperandTypes) {
const bothVectors = wgsl.isVec(convLhs.dataType) && wgsl.isVec(convRhs.dataType);
throw new WgslTypeError(
`Comparison '${op}' requires numeric${equalityCheck ? ' or boolean' : ''} operands. Got '${String(convLhs.dataType)}' and '${String(convRhs.dataType)}'.${
bothVectors
? ` For component-wise comparison, use 'std.${stdBinaryRelationalOp}'.`
: ''
}`,
);
}
}

if (
(op === '&&' || op === '||') &&
!(wgsl.isBool(convLhs.dataType) && wgsl.isBool(convRhs.dataType))
) {
throw new WgslTypeError(
`Logical expression '${op}' requires boolean operands. Got '${String(convLhs.dataType)}' and '${String(convRhs.dataType)}'.`,
);
}

return snip(
parenthesizedOps.includes(op)
? `(${lhsStr} ${OP_MAP[op] ?? op} ${rhsStr})`
Expand Down Expand Up @@ -859,6 +898,13 @@ ${this.ctx.pre}}`;
if (isKnownAtComptime(test)) {
return test.value ? this._expression(consequentNode) : this._expression(alternativeNode);
} else {
const convertedTest = tryConvertSnippet(this.ctx, test, bool, false);
if (!convertedTest) {
throw new Error(
`Ternary operator '${stringifyNode(expression)}' is invalid. Cannot convert condition to bool.`,
);
}

const consequent = this._expression(consequentNode);
const alternative = this._expression(alternativeNode);
const [con, alt] =
Expand All @@ -871,11 +917,11 @@ ${this.ctx.pre}}`;
}

return snip(
stitch`select(${alt}, ${con}, ${test})`,
stitch`select(${alt}, ${con}, ${convertedTest})`,
con.dataType,
'runtime',
// this select has side-effects only if the condition has side-effects
test.possibleSideEffects,
convertedTest.possibleSideEffects,
);
}
}
Expand Down
Loading
Loading