@@ -101,6 +101,7 @@ struct metal_print_context
101101 , paramsStr(ralloc_strdup(buffer, " " ))
102102 , writingParams(false )
103103 , matrixCastsDone(false )
104+ , matrixConstructorsDone(false )
104105 , shadowSamplerDone(false )
105106 , textureCounter(0 )
106107 , attributeCounter(0 )
@@ -118,6 +119,7 @@ struct metal_print_context
118119 string_buffer paramsStr;
119120 bool writingParams;
120121 bool matrixCastsDone;
122+ bool matrixConstructorsDone;
121123 bool shadowSamplerDone;
122124 int textureCounter;
123125 int attributeCounter;
@@ -940,21 +942,57 @@ void ir_print_metal_visitor::visit(ir_expression *ir)
940942 bool op0cast = ir->operands [0 ] && is_different_precision (arg_prec, ir->operands [0 ]->get_precision ());
941943 bool op1cast = ir->operands [1 ] && is_different_precision (arg_prec, ir->operands [1 ]->get_precision ());
942944 bool op2cast = ir->operands [2 ] && is_different_precision (arg_prec, ir->operands [2 ]->get_precision ());
945+ const bool op0matrix = ir->operands [0 ] && ir->operands [0 ]->type ->is_matrix ();
946+ const bool op1matrix = ir->operands [1 ] && ir->operands [1 ]->type ->is_matrix ();
947+ bool op0castTo1 = false ;
948+ bool op1castTo0 = false ;
943949
944950 // Metal does not support matrix precision casts, so when any of the arguments is a matrix,
945951 // take precision from it. This isn't fully robust now, but oh well.
946- if (op0cast && ir-> operands [ 0 ]-> type -> is_matrix () && !op1cast)
952+ if (op0cast && op0matrix && !op1cast)
947953 {
948954 op0cast = false ;
949955 arg_prec = ir->operands [0 ]->get_precision ();
950956 op1cast = ir->operands [1 ] && is_different_precision (arg_prec, ir->operands [1 ]->get_precision ());
951957 }
952- if (op1cast && ir-> operands [ 1 ]-> type -> is_matrix () && !op0cast)
958+ if (op1cast && op1matrix && !op0cast)
953959 {
954960 op1cast = false ;
955961 arg_prec = ir->operands [1 ]->get_precision ();
956962 op0cast = ir->operands [0 ] && is_different_precision (arg_prec, ir->operands [0 ]->get_precision ());
957963 }
964+
965+ // Metal does not have matrix+scalar and matrix-scalar operations; we need to create matrices
966+ // out of the non-matrix argument.
967+ if (ir->operation == ir_binop_add || ir->operation == ir_binop_sub)
968+ {
969+ if (op0matrix && !op1matrix)
970+ {
971+ op1cast = true ;
972+ op1castTo0 = true ;
973+ }
974+ if (op1matrix && !op0matrix)
975+ {
976+ op0cast = true ;
977+ op0castTo1 = true ;
978+ }
979+ if (op1castTo0 || op0castTo1)
980+ {
981+ if (!ctx.matrixConstructorsDone )
982+ {
983+ ctx.prefixStr .asprintf_append (
984+ " inline float4x4 _xlinit_float4x4(float v) { return float4x4(float4(v), float4(v), float4(v), float4(v)); }\n "
985+ " inline float3x3 _xlinit_float3x3(float v) { return float3x3(float3(v), float3(v), float3(v)); }\n "
986+ " inline float2x2 _xlinit_float2x2(float v) { return float2x2(float2(v), float2(v)); }\n "
987+ " inline half4x4 _xlinit_half4x4(half v) { return half4x4(half4(v), half4(v), half4(v), half4(v)); }\n "
988+ " inline half3x3 _xlinit_half3x3(half v) { return half3x3(half3(v), half3(v), half3(v)); }\n "
989+ " inline half2x2 _xlinit_half2x2(half v) { return half2x2(half2(v), half2(v)); }\n "
990+ );
991+ ctx.matrixConstructorsDone = true ;
992+ }
993+ }
994+ }
995+
958996
959997 const bool rescast = is_different_precision (arg_prec, res_prec) && !ir->type ->is_boolean ();
960998 if (rescast)
@@ -1000,6 +1038,7 @@ void ir_print_metal_visitor::visit(ir_expression *ir)
10001038 }
10011039 else if (is_binop_func_like (ir->operation , ir->type ))
10021040 {
1041+ // binary operation that must be printed like a function, "foo(a,b)"
10031042 if (ir->operation == ir_binop_mod)
10041043 {
10051044 buffer.asprintf_append (" (" );
@@ -1025,23 +1064,58 @@ void ir_print_metal_visitor::visit(ir_expression *ir)
10251064 if (ir->operation == ir_binop_mod)
10261065 buffer.asprintf_append (" ))" );
10271066 }
1067+ else if (ir->get_num_operands () == 2 && ir->operation == ir_binop_div && op0matrix && !op1matrix)
1068+ {
1069+ // "matrix/scalar" - Metal does not have it, so print multiply by inverse instead
1070+ buffer.asprintf_append (" (" );
1071+ ir->operands [0 ]->accept (this );
1072+ const bool halfCast = (arg_prec == glsl_precision_medium || arg_prec == glsl_precision_low);
1073+ buffer.asprintf_append (halfCast ? " * (1.0h/half(" : " * (1.0/(" );
1074+ ir->operands [1 ]->accept (this );
1075+ buffer.asprintf_append (" )))" );
1076+ }
10281077 else if (ir->get_num_operands () == 2 )
10291078 {
1079+ // regular binary operator
10301080 buffer.asprintf_append (" (" );
10311081 if (ir->operands [0 ])
10321082 {
1033- if (op0cast)
1083+ if (op0castTo1)
1084+ {
1085+ buffer.asprintf_append (" _xlinit_" );
1086+ print_type_precision (buffer, ir->operands [1 ]->type , arg_prec, false );
1087+ buffer.asprintf_append (" (" );
1088+ }
1089+ else if (op0cast)
1090+ {
10341091 print_cast (buffer, arg_prec, ir->operands [0 ]);
1092+ }
10351093 ir->operands [0 ]->accept (this );
1094+ if (op0castTo1)
1095+ {
1096+ buffer.asprintf_append (" )" );
1097+ }
10361098 }
10371099
10381100 buffer.asprintf_append (" %s " , operator_glsl_strs[ir->operation ]);
10391101
10401102 if (ir->operands [1 ])
10411103 {
1042- if (op1cast)
1104+ if (op1castTo0)
1105+ {
1106+ buffer.asprintf_append (" _xlinit_" );
1107+ print_type_precision (buffer, ir->operands [0 ]->type , arg_prec, false );
1108+ buffer.asprintf_append (" (" );
1109+ }
1110+ else if (op1cast)
1111+ {
10431112 print_cast (buffer, arg_prec, ir->operands [1 ]);
1113+ }
10441114 ir->operands [1 ]->accept (this );
1115+ if (op1castTo0)
1116+ {
1117+ buffer.asprintf_append (" )" );
1118+ }
10451119 }
10461120 buffer.asprintf_append (" )" );
10471121 }
0 commit comments