diff --git a/src/strands/ir_types.js b/src/strands/ir_types.js index 9f480d5c9d..5347ba81d9 100644 --- a/src/strands/ir_types.js +++ b/src/strands/ir_types.js @@ -130,6 +130,7 @@ export const OpCode = { Nary: { FUNCTION_CALL: 200, CONSTRUCTOR: 201, + TERNARY: 202, }, ControlFlow: { RETURN: 300, diff --git a/src/strands/strands_api.js b/src/strands/strands_api.js index ef4c0424c3..6239e5a314 100644 --- a/src/strands/strands_api.js +++ b/src/strands/strands_api.js @@ -15,6 +15,7 @@ import { import { strandsBuiltinFunctions } from './strands_builtins' import { StrandsConditional } from './strands_conditionals' import { StrandsFor } from './strands_for' +import { buildTernary } from './strands_ternary' import * as CFG from './ir_cfg' import * as DAG from './ir_dag'; import * as FES from './strands_FES' @@ -194,6 +195,10 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) { return new StrandsFor(strandsContext, initialCb, conditionCb, updateCb, bodyCb, initialVars).build(); }; augmentFn(fn, p5, 'strandsFor', p5.strandsFor); + p5.strandsTernary = function(condition, ifTrue, ifFalse) { + return buildTernary(strandsContext, condition, ifTrue, ifFalse); + }; + augmentFn(fn, p5, 'strandsTernary', p5.strandsTernary); p5.strandsEarlyReturn = function(value) { const { dag, cfg } = strandsContext; diff --git a/src/strands/strands_ternary.js b/src/strands/strands_ternary.js new file mode 100644 index 0000000000..dcd84522ce --- /dev/null +++ b/src/strands/strands_ternary.js @@ -0,0 +1,53 @@ +import * as DAG from './ir_dag'; +import * as CFG from './ir_cfg'; +import { NodeType, OpCode, BaseType } from './ir_types'; +import { createStrandsNode } from './strands_node'; +import * as FES from './strands_FES'; + +export function buildTernary(strandsContext, condition, ifTrue, ifFalse) { + const { dag, cfg, p5 } = strandsContext; + + // Ensure all inputs are StrandsNodes + const condNode = condition?.isStrandsNode ? condition : p5.strandsNode(condition); + const trueNode = ifTrue?.isStrandsNode ? ifTrue : p5.strandsNode(ifTrue); + const falseNode = ifFalse?.isStrandsNode ? ifFalse : p5.strandsNode(ifFalse); + + // Get type info for both nodes + let trueType = DAG.extractNodeTypeInfo(dag, trueNode.id); + let falseType = DAG.extractNodeTypeInfo(dag, falseNode.id); + + // Propagate type from the known branch to any ASSIGN_ON_USE branch + if (trueType.baseType === BaseType.ASSIGN_ON_USE && falseType.baseType !== BaseType.ASSIGN_ON_USE) { + DAG.propagateTypeToAssignOnUse(dag, trueNode.id, falseType.baseType, falseType.dimension); + trueType = DAG.extractNodeTypeInfo(dag, trueNode.id); + } else if (falseType.baseType === BaseType.ASSIGN_ON_USE && trueType.baseType !== BaseType.ASSIGN_ON_USE) { + DAG.propagateTypeToAssignOnUse(dag, falseNode.id, trueType.baseType, trueType.dimension); + falseType = DAG.extractNodeTypeInfo(dag, falseNode.id); + } + + // After ASSIGN_ON_USE propagation, if both types are known, they must match + if ( + trueType.baseType !== BaseType.ASSIGN_ON_USE && + falseType.baseType !== BaseType.ASSIGN_ON_USE && + (trueType.baseType !== falseType.baseType || trueType.dimension !== falseType.dimension) + ) { + FES.userError('type error', + 'The true and false branches of a ternary expression must have the same type. ' + + `Right now, the true branch is a ${trueType.baseType}${trueType.dimension}, and the false branch is a ${falseType.baseType}${falseType.dimension}.` + ); + } + + const resultType = trueType; + + const nodeData = DAG.createNodeData({ + nodeType: NodeType.OPERATION, + opCode: OpCode.Nary.TERNARY, + dependsOn: [condNode.id, trueNode.id, falseNode.id], + baseType: resultType.baseType, + dimension: resultType.dimension, + }); + + const id = DAG.getOrCreateNode(dag, nodeData); + CFG.recordInBasicBlock(cfg, cfg.currentBlock, id); + return createStrandsNode(id, resultType.dimension, strandsContext); +} diff --git a/src/strands/strands_transpiler.js b/src/strands/strands_transpiler.js index 836a177c6c..df0e770795 100644 --- a/src/strands/strands_transpiler.js +++ b/src/strands/strands_transpiler.js @@ -465,6 +465,20 @@ const ASTCallbacks = { }; node.arguments = [node.right]; }, + ConditionalExpression(node, _state, ancestors) { + if (ancestors.some(nodeIsUniform)) { return; } + // Transform condition ? consequent : alternate + // into __p5.strandsTernary(condition, consequent, alternate) + const test = node.test; + const consequent = node.consequent; + const alternate = node.alternate; + node.type = 'CallExpression'; + node.callee = { type: 'Identifier', name: '__p5.strandsTernary' }; + node.arguments = [test, consequent, alternate]; + delete node.test; + delete node.consequent; + delete node.alternate; + }, IfStatement(node, _state, ancestors) { if (ancestors.some(nodeIsUniform)) { return; } // Transform if statement into strandsIf() call diff --git a/src/webgl/strands_glslBackend.js b/src/webgl/strands_glslBackend.js index daf804a8e8..1004487a66 100644 --- a/src/webgl/strands_glslBackend.js +++ b/src/webgl/strands_glslBackend.js @@ -289,6 +289,13 @@ export const glslBackend = { const functionArgs = node.dependsOn.map(arg =>this.generateExpression(generationContext, dag, arg)); return `${node.identifier}(${functionArgs.join(', ')})`; } + if (node.opCode === OpCode.Nary.TERNARY) { + const [condID, trueID, falseID] = node.dependsOn; + const cond = this.generateExpression(generationContext, dag, condID); + const trueExpr = this.generateExpression(generationContext, dag, trueID); + const falseExpr = this.generateExpression(generationContext, dag, falseID); + return `(${cond} ? ${trueExpr} : ${falseExpr})`; + } if (node.opCode === OpCode.Binary.MEMBER_ACCESS) { const [lID, rID] = node.dependsOn; const lName = this.generateExpression(generationContext, dag, lID); diff --git a/src/webgpu/strands_wgslBackend.js b/src/webgpu/strands_wgslBackend.js index 4394210414..81369446fc 100644 --- a/src/webgpu/strands_wgslBackend.js +++ b/src/webgpu/strands_wgslBackend.js @@ -396,6 +396,13 @@ export const wgslBackend = { const deps = node.dependsOn.map((dep) => this.generateExpression(generationContext, dag, dep)); return `${T}(${deps.join(', ')})`; } + if (node.opCode === OpCode.Nary.TERNARY) { + const [condID, trueID, falseID] = node.dependsOn; + const cond = this.generateExpression(generationContext, dag, condID); + const trueExpr = this.generateExpression(generationContext, dag, trueID); + const falseExpr = this.generateExpression(generationContext, dag, falseID); + return `select(${falseExpr}, ${trueExpr}, ${cond})`; + } if (node.opCode === OpCode.Nary.FUNCTION_CALL) { // Convert mod() function calls to % operator in WGSL if (node.identifier === 'mod' && node.dependsOn.length === 2) { diff --git a/test/unit/webgl/p5.Shader.js b/test/unit/webgl/p5.Shader.js index 2556c1d25d..c2a16b21e7 100644 --- a/test/unit/webgl/p5.Shader.js +++ b/test/unit/webgl/p5.Shader.js @@ -1204,6 +1204,55 @@ test('returns numbers for builtin globals outside hooks and a strandNode when ca }); }); + suite('ternary expressions', () => { + test('ternary changes color based on left/right side of canvas', () => { + myp5.createCanvas(50, 25, myp5.WEBGL); + const testShader = myp5.baseMaterialShader().modify(() => { + myp5.getPixelInputs(inputs => { + inputs.color = inputs.texCoord.x > 0.5 ? [1, 0, 0, 1] : [0, 0, 1, 1]; + return inputs; + }); + }, { myp5 }); + myp5.noStroke(); + myp5.shader(testShader); + myp5.plane(myp5.width, myp5.height); + + const leftPixel = myp5.get(12, 12); + assert.approximately(leftPixel[0], 0, 5); + assert.approximately(leftPixel[1], 0, 5); + assert.approximately(leftPixel[2], 255, 5); + + const rightPixel = myp5.get(37, 12); + assert.approximately(rightPixel[0], 255, 5); + assert.approximately(rightPixel[1], 0, 5); + assert.approximately(rightPixel[2], 0, 5); + }); + + test('ternary with scalar values', () => { + myp5.createCanvas(50, 25, myp5.WEBGL); + const testShader = myp5.baseMaterialShader().modify(() => { + myp5.getPixelInputs(inputs => { + const brightness = inputs.texCoord.x > 0.5 ? 1.0 : 0.0; + inputs.color = [brightness, brightness, brightness, 1]; + return inputs; + }); + }, { myp5 }); + myp5.noStroke(); + myp5.shader(testShader); + myp5.plane(myp5.width, myp5.height); + + const leftPixel = myp5.get(12, 12); + assert.approximately(leftPixel[0], 0, 5); + assert.approximately(leftPixel[1], 0, 5); + assert.approximately(leftPixel[2], 0, 5); + + const rightPixel = myp5.get(37, 12); + assert.approximately(rightPixel[0], 255, 5); + assert.approximately(rightPixel[1], 255, 5); + assert.approximately(rightPixel[2], 255, 5); + }); + }); + suite('for loop statements', () => { test('handle simple for loop with known iteration count', () => { myp5.createCanvas(50, 50, myp5.WEBGL); @@ -2183,5 +2232,26 @@ test('returns numbers for builtin globals outside hooks and a strandNode when ca assert.include(errMsg, 'Expected properties'); assert.include(errMsg, 'Received properties'); }); + + test('ternary with mismatched branch types shows both types in error', () => { + myp5.createCanvas(50, 50, myp5.WEBGL); + + try { + myp5.baseMaterialShader().modify(() => { + myp5.getPixelInputs(inputs => { + // float1 vs float4 - type mismatch + const val = inputs.texCoord.x > 0.5 ? myp5.float(1.0) : [1, 0, 0, 1]; + inputs.color = [val, val, val, 1]; + return inputs; + }); + }, { myp5 }); + } catch (e) { /* expected */ } + + assert.isAbove(mockUserError.mock.calls.length, 0, 'FES.userError should have been called'); + const errMsg = mockUserError.mock.calls[0][1]; + assert.include(errMsg, 'ternary'); + assert.include(errMsg, 'float1'); + assert.include(errMsg, 'float4'); + }); }); }); diff --git a/test/unit/webgpu/p5.Shader.js b/test/unit/webgpu/p5.Shader.js index ba3bcd6bdc..7274ec84bf 100644 --- a/test/unit/webgpu/p5.Shader.js +++ b/test/unit/webgpu/p5.Shader.js @@ -490,7 +490,6 @@ suite('WebGPU p5.Shader', function() { return [0.4, 0, 0, 1]; }); }, { myp5 }); - console.log(testShader.fragSrc()) myp5.background(255, 255, 255); myp5.filter(testShader); @@ -502,6 +501,55 @@ suite('WebGPU p5.Shader', function() { }); }); + suite('ternary expressions', () => { + test('ternary changes color based on left/right side of canvas', async () => { + await myp5.createCanvas(50, 25, myp5.WEBGPU); + const testShader = myp5.baseMaterialShader().modify(() => { + myp5.getPixelInputs(inputs => { + inputs.color = inputs.texCoord.x > 0.5 ? [1, 0, 0, 1] : [0, 0, 1, 1]; + return inputs; + }); + }, { myp5 }); + myp5.noStroke(); + myp5.shader(testShader); + myp5.plane(myp5.width, myp5.height); + + const leftPixel = await myp5.get(12, 12); + assert.approximately(leftPixel[0], 0, 5); + assert.approximately(leftPixel[1], 0, 5); + assert.approximately(leftPixel[2], 255, 5); + + const rightPixel = await myp5.get(37, 12); + assert.approximately(rightPixel[0], 255, 5); + assert.approximately(rightPixel[1], 0, 5); + assert.approximately(rightPixel[2], 0, 5); + }); + + test('ternary with scalar values', async () => { + await myp5.createCanvas(50, 25, myp5.WEBGPU); + const testShader = myp5.baseMaterialShader().modify(() => { + myp5.getPixelInputs(inputs => { + const brightness = inputs.texCoord.x > 0.5 ? 1.0 : 0.0; + inputs.color = [brightness, brightness, brightness, 1]; + return inputs; + }); + }, { myp5 }); + myp5.noStroke(); + myp5.shader(testShader); + myp5.plane(myp5.width, myp5.height); + + const leftPixel = await myp5.get(12, 12); + assert.approximately(leftPixel[0], 0, 5); + assert.approximately(leftPixel[1], 0, 5); + assert.approximately(leftPixel[2], 0, 5); + + const rightPixel = await myp5.get(37, 12); + assert.approximately(rightPixel[0], 255, 5); + assert.approximately(rightPixel[1], 255, 5); + assert.approximately(rightPixel[2], 255, 5); + }); + }); + suite('for loop statements', () => { test('handle simple for loop with known iteration count', async () => { await myp5.createCanvas(50, 50, myp5.WEBGPU);