Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/strands/ir_types.js
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ export const OpCode = {
Nary: {
FUNCTION_CALL: 200,
CONSTRUCTOR: 201,
TERNARY: 202,
},
ControlFlow: {
RETURN: 300,
Expand Down
5 changes: 5 additions & 0 deletions src/strands/strands_api.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
import { strandsBuiltinFunctions } from './strands_builtins'
import { StrandsConditional } from './strands_conditionals'
import { StrandsFor } from './strands_for'
import { buildTernary } from './strands_ternary'
import * as CFG from './ir_cfg'
import * as DAG from './ir_dag';
import * as FES from './strands_FES'
Expand Down Expand Up @@ -194,6 +195,10 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
return new StrandsFor(strandsContext, initialCb, conditionCb, updateCb, bodyCb, initialVars).build();
};
augmentFn(fn, p5, 'strandsFor', p5.strandsFor);
p5.strandsTernary = function(condition, ifTrue, ifFalse) {
return buildTernary(strandsContext, condition, ifTrue, ifFalse);
};
augmentFn(fn, p5, 'strandsTernary', p5.strandsTernary);
p5.strandsEarlyReturn = function(value) {
const { dag, cfg } = strandsContext;

Expand Down
53 changes: 53 additions & 0 deletions src/strands/strands_ternary.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import * as DAG from './ir_dag';
import * as CFG from './ir_cfg';
import { NodeType, OpCode, BaseType } from './ir_types';
import { createStrandsNode } from './strands_node';
import * as FES from './strands_FES';

export function buildTernary(strandsContext, condition, ifTrue, ifFalse) {
const { dag, cfg, p5 } = strandsContext;

// Ensure all inputs are StrandsNodes
const condNode = condition?.isStrandsNode ? condition : p5.strandsNode(condition);
const trueNode = ifTrue?.isStrandsNode ? ifTrue : p5.strandsNode(ifTrue);
const falseNode = ifFalse?.isStrandsNode ? ifFalse : p5.strandsNode(ifFalse);

// Get type info for both nodes
let trueType = DAG.extractNodeTypeInfo(dag, trueNode.id);
let falseType = DAG.extractNodeTypeInfo(dag, falseNode.id);

// Propagate type from the known branch to any ASSIGN_ON_USE branch
if (trueType.baseType === BaseType.ASSIGN_ON_USE && falseType.baseType !== BaseType.ASSIGN_ON_USE) {
DAG.propagateTypeToAssignOnUse(dag, trueNode.id, falseType.baseType, falseType.dimension);
trueType = DAG.extractNodeTypeInfo(dag, trueNode.id);
} else if (falseType.baseType === BaseType.ASSIGN_ON_USE && trueType.baseType !== BaseType.ASSIGN_ON_USE) {
DAG.propagateTypeToAssignOnUse(dag, falseNode.id, trueType.baseType, trueType.dimension);
falseType = DAG.extractNodeTypeInfo(dag, falseNode.id);
}

// After ASSIGN_ON_USE propagation, if both types are known, they must match
if (
trueType.baseType !== BaseType.ASSIGN_ON_USE &&
falseType.baseType !== BaseType.ASSIGN_ON_USE &&
(trueType.baseType !== falseType.baseType || trueType.dimension !== falseType.dimension)
) {
FES.userError('type error',
'The true and false branches of a ternary expression must have the same type. ' +
`Right now, the true branch is a ${trueType.baseType}${trueType.dimension}, and the false branch is a ${falseType.baseType}${falseType.dimension}.`
);
}

const resultType = trueType;

const nodeData = DAG.createNodeData({
nodeType: NodeType.OPERATION,
opCode: OpCode.Nary.TERNARY,
dependsOn: [condNode.id, trueNode.id, falseNode.id],
baseType: resultType.baseType,
dimension: resultType.dimension,
});

const id = DAG.getOrCreateNode(dag, nodeData);
CFG.recordInBasicBlock(cfg, cfg.currentBlock, id);
return createStrandsNode(id, resultType.dimension, strandsContext);
}
14 changes: 14 additions & 0 deletions src/strands/strands_transpiler.js
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,20 @@ const ASTCallbacks = {
};
node.arguments = [node.right];
},
ConditionalExpression(node, _state, ancestors) {
if (ancestors.some(nodeIsUniform)) { return; }
// Transform condition ? consequent : alternate
// into __p5.strandsTernary(condition, consequent, alternate)
const test = node.test;
const consequent = node.consequent;
const alternate = node.alternate;
node.type = 'CallExpression';
node.callee = { type: 'Identifier', name: '__p5.strandsTernary' };
node.arguments = [test, consequent, alternate];
delete node.test;
delete node.consequent;
delete node.alternate;
},
IfStatement(node, _state, ancestors) {
if (ancestors.some(nodeIsUniform)) { return; }
// Transform if statement into strandsIf() call
Expand Down
7 changes: 7 additions & 0 deletions src/webgl/strands_glslBackend.js
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,13 @@ export const glslBackend = {
const functionArgs = node.dependsOn.map(arg =>this.generateExpression(generationContext, dag, arg));
return `${node.identifier}(${functionArgs.join(', ')})`;
}
if (node.opCode === OpCode.Nary.TERNARY) {
const [condID, trueID, falseID] = node.dependsOn;
const cond = this.generateExpression(generationContext, dag, condID);
const trueExpr = this.generateExpression(generationContext, dag, trueID);
const falseExpr = this.generateExpression(generationContext, dag, falseID);
return `(${cond} ? ${trueExpr} : ${falseExpr})`;
}
if (node.opCode === OpCode.Binary.MEMBER_ACCESS) {
const [lID, rID] = node.dependsOn;
const lName = this.generateExpression(generationContext, dag, lID);
Expand Down
7 changes: 7 additions & 0 deletions src/webgpu/strands_wgslBackend.js
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,13 @@ export const wgslBackend = {
const deps = node.dependsOn.map((dep) => this.generateExpression(generationContext, dag, dep));
return `${T}(${deps.join(', ')})`;
}
if (node.opCode === OpCode.Nary.TERNARY) {
const [condID, trueID, falseID] = node.dependsOn;
const cond = this.generateExpression(generationContext, dag, condID);
const trueExpr = this.generateExpression(generationContext, dag, trueID);
const falseExpr = this.generateExpression(generationContext, dag, falseID);
return `select(${falseExpr}, ${trueExpr}, ${cond})`;
}
if (node.opCode === OpCode.Nary.FUNCTION_CALL) {
// Convert mod() function calls to % operator in WGSL
if (node.identifier === 'mod' && node.dependsOn.length === 2) {
Expand Down
70 changes: 70 additions & 0 deletions test/unit/webgl/p5.Shader.js
Original file line number Diff line number Diff line change
Expand Up @@ -1204,6 +1204,55 @@ test('returns numbers for builtin globals outside hooks and a strandNode when ca
});
});

suite('ternary expressions', () => {
test('ternary changes color based on left/right side of canvas', () => {
myp5.createCanvas(50, 25, myp5.WEBGL);
const testShader = myp5.baseMaterialShader().modify(() => {
myp5.getPixelInputs(inputs => {
inputs.color = inputs.texCoord.x > 0.5 ? [1, 0, 0, 1] : [0, 0, 1, 1];
return inputs;
});
}, { myp5 });
myp5.noStroke();
myp5.shader(testShader);
myp5.plane(myp5.width, myp5.height);

const leftPixel = myp5.get(12, 12);
assert.approximately(leftPixel[0], 0, 5);
assert.approximately(leftPixel[1], 0, 5);
assert.approximately(leftPixel[2], 255, 5);

const rightPixel = myp5.get(37, 12);
assert.approximately(rightPixel[0], 255, 5);
assert.approximately(rightPixel[1], 0, 5);
assert.approximately(rightPixel[2], 0, 5);
});

test('ternary with scalar values', () => {
myp5.createCanvas(50, 25, myp5.WEBGL);
const testShader = myp5.baseMaterialShader().modify(() => {
myp5.getPixelInputs(inputs => {
const brightness = inputs.texCoord.x > 0.5 ? 1.0 : 0.0;
inputs.color = [brightness, brightness, brightness, 1];
return inputs;
});
}, { myp5 });
myp5.noStroke();
myp5.shader(testShader);
myp5.plane(myp5.width, myp5.height);

const leftPixel = myp5.get(12, 12);
assert.approximately(leftPixel[0], 0, 5);
assert.approximately(leftPixel[1], 0, 5);
assert.approximately(leftPixel[2], 0, 5);

const rightPixel = myp5.get(37, 12);
assert.approximately(rightPixel[0], 255, 5);
assert.approximately(rightPixel[1], 255, 5);
assert.approximately(rightPixel[2], 255, 5);
});
});

suite('for loop statements', () => {
test('handle simple for loop with known iteration count', () => {
myp5.createCanvas(50, 50, myp5.WEBGL);
Expand Down Expand Up @@ -2183,5 +2232,26 @@ test('returns numbers for builtin globals outside hooks and a strandNode when ca
assert.include(errMsg, 'Expected properties');
assert.include(errMsg, 'Received properties');
});

test('ternary with mismatched branch types shows both types in error', () => {
myp5.createCanvas(50, 50, myp5.WEBGL);

try {
myp5.baseMaterialShader().modify(() => {
myp5.getPixelInputs(inputs => {
// float1 vs float4 - type mismatch
const val = inputs.texCoord.x > 0.5 ? myp5.float(1.0) : [1, 0, 0, 1];
inputs.color = [val, val, val, 1];
return inputs;
});
}, { myp5 });
} catch (e) { /* expected */ }

assert.isAbove(mockUserError.mock.calls.length, 0, 'FES.userError should have been called');
const errMsg = mockUserError.mock.calls[0][1];
assert.include(errMsg, 'ternary');
assert.include(errMsg, 'float1');
assert.include(errMsg, 'float4');
});
});
});
50 changes: 49 additions & 1 deletion test/unit/webgpu/p5.Shader.js
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ suite('WebGPU p5.Shader', function() {
return [0.4, 0, 0, 1];
});
}, { myp5 });
console.log(testShader.fragSrc())

myp5.background(255, 255, 255);
myp5.filter(testShader);
Expand All @@ -502,6 +501,55 @@ suite('WebGPU p5.Shader', function() {
});
});

suite('ternary expressions', () => {
test('ternary changes color based on left/right side of canvas', async () => {
await myp5.createCanvas(50, 25, myp5.WEBGPU);
const testShader = myp5.baseMaterialShader().modify(() => {
myp5.getPixelInputs(inputs => {
inputs.color = inputs.texCoord.x > 0.5 ? [1, 0, 0, 1] : [0, 0, 1, 1];
return inputs;
});
}, { myp5 });
myp5.noStroke();
myp5.shader(testShader);
myp5.plane(myp5.width, myp5.height);

const leftPixel = await myp5.get(12, 12);
assert.approximately(leftPixel[0], 0, 5);
assert.approximately(leftPixel[1], 0, 5);
assert.approximately(leftPixel[2], 255, 5);

const rightPixel = await myp5.get(37, 12);
assert.approximately(rightPixel[0], 255, 5);
assert.approximately(rightPixel[1], 0, 5);
assert.approximately(rightPixel[2], 0, 5);
});

test('ternary with scalar values', async () => {
await myp5.createCanvas(50, 25, myp5.WEBGPU);
const testShader = myp5.baseMaterialShader().modify(() => {
myp5.getPixelInputs(inputs => {
const brightness = inputs.texCoord.x > 0.5 ? 1.0 : 0.0;
inputs.color = [brightness, brightness, brightness, 1];
return inputs;
});
}, { myp5 });
myp5.noStroke();
myp5.shader(testShader);
myp5.plane(myp5.width, myp5.height);

const leftPixel = await myp5.get(12, 12);
assert.approximately(leftPixel[0], 0, 5);
assert.approximately(leftPixel[1], 0, 5);
assert.approximately(leftPixel[2], 0, 5);

const rightPixel = await myp5.get(37, 12);
assert.approximately(rightPixel[0], 255, 5);
assert.approximately(rightPixel[1], 255, 5);
assert.approximately(rightPixel[2], 255, 5);
});
});

suite('for loop statements', () => {
test('handle simple for loop with known iteration count', async () => {
await myp5.createCanvas(50, 50, myp5.WEBGPU);
Expand Down
Loading