Skip to content

Commit dbd39c1

Browse files
committed
use optimized network evaluation shader code
1 parent 2011cd8 commit dbd39c1

File tree

4 files changed

+87
-140
lines changed

4 files changed

+87
-140
lines changed

Editor/MobileNeRFImporter.cs

Lines changed: 19 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
using System.Text;
55
using System.Text.RegularExpressions;
66
using System.Threading.Tasks;
7-
using Unity.Collections;
87
using UnityEditor;
98
using UnityEngine;
109
using static WebRequestAsyncUtility;
@@ -74,7 +73,7 @@ public static void ImportAssetsFromDisk() {
7473
}
7574
}
7675

77-
// ask for axis siwtch behaviour
76+
// ask for axis switch behaviour
7877
if (EditorUtility.DisplayDialog(SwitchAxisTitle, SwitchAxisMsg, Switch, NoSwitch)) {
7978
SwizzleAxis = true;
8079
} else {
@@ -165,7 +164,7 @@ public static void DownloadStumpAssets() {
165164

166165
/// <summary>
167166
/// Some scenes require switching the y and z axis in the shader.
168-
/// For custom scenes this tracks whether which one should be used.
167+
/// For custom scenes this tracks, which one should be used.
169168
/// </summary>
170169
public static bool SwizzleAxis = false;
171170

@@ -193,12 +192,6 @@ private static string GetMLPAssetPath(string objName) {
193192
Directory.CreateDirectory(Path.GetDirectoryName(path));
194193
return path;
195194
}
196-
197-
private static string GetWeightsAssetPath(string objName, int i) {
198-
string path = $"{GetBasePath(objName)}/MLP/weightsTex{i}.asset";
199-
Directory.CreateDirectory(Path.GetDirectoryName(path));
200-
return path;
201-
}
202195

203196
private static string GetFeatureTextureAssetPath(string objName, int shapeNum, int featureNum) {
204197
string path = $"{GetBasePath(objName)}/PNGs/shape{shapeNum}.pngfeat{featureNum}.png";
@@ -297,13 +290,12 @@ private static async Task ImportDemoSceneAsync(MNeRFScene scene) {
297290

298291
/// <summary>
299292
/// Set specific import settings on OBJs/PNGs.
300-
/// Creates Weight Textures, Materials and Shader from MLP data.
293+
/// Creates Materials and Shader from MLP data.
301294
/// Creates a convenient prefab for the MobileNeRF object.
302295
/// </summary>
303296
private static void ProcessAssets(string objName) {
304297
Mlp mlp = GetMlp(objName);
305298
CreateShader(objName, mlp);
306-
CreateWeightTextures(objName, mlp);
307299
// PNGs are configured in PNGImportProcessor.cs
308300
ProcessOBJs(objName, mlp);
309301
CreatePrefab(objName, mlp);
@@ -499,14 +491,6 @@ private static void ProcessOBJs(string objName, Mlp mlp) {
499491
Material material = AssetDatabase.LoadAssetAtPath<Material>(materialAssetPath);
500492
material.shader = mobileNeRFShader;
501493

502-
// assign weight textures
503-
Texture2D weightsTexZero = AssetDatabase.LoadAssetAtPath<Texture2D>(GetWeightsAssetPath(objName, 0));
504-
Texture2D weightsTexOne = AssetDatabase.LoadAssetAtPath<Texture2D>(GetWeightsAssetPath(objName, 1));
505-
Texture2D weightsTexTwo = AssetDatabase.LoadAssetAtPath<Texture2D>(GetWeightsAssetPath(objName, 2));
506-
material.SetTexture("weightsZero", weightsTexZero);
507-
material.SetTexture("weightsOne", weightsTexOne);
508-
material.SetTexture("weightsTwo", weightsTexTwo);
509-
510494
// assign feature textures
511495
string feat0AssetPath = GetFeatureTextureAssetPath(objName, i, 0);
512496
string feat1AssetPath = GetFeatureTextureAssetPath(objName, i, 1);
@@ -528,25 +512,26 @@ private static void ProcessOBJs(string objName, Mlp mlp) {
528512
private static void CreateShader(string objName, Mlp mlp) {
529513
int width = mlp._0Bias.Length;
530514

531-
StringBuilder biasListZero = toBiasList(mlp._0Bias);
532-
StringBuilder biasListOne = toBiasList(mlp._1Bias);
533-
StringBuilder biasListTwo = toBiasList(mlp._2Bias);
534-
535-
int channelsZero = mlp._0Weights.Length;
536-
int channelsOne = mlp._0Bias.Length;
537-
int channelsTwo = mlp._1Bias.Length;
538-
int channelsThree = mlp._2Bias.Length;
515+
StringBuilder biasListZero = toConstructorList(mlp._0Bias);
516+
StringBuilder biasListOne = toConstructorList(mlp._1Bias);
517+
StringBuilder biasListTwo = toConstructorList(mlp._2Bias);
539518

540519
string shaderSource = ViewDependenceNetworkShader.Template;
541520
shaderSource = new Regex("OBJECT_NAME" ).Replace(shaderSource, $"{objName}");
542-
shaderSource = new Regex("NUM_CHANNELS_ZERO" ).Replace(shaderSource, $"{channelsZero}");
543-
shaderSource = new Regex("NUM_CHANNELS_ONE" ).Replace(shaderSource, $"{channelsOne}");
544-
shaderSource = new Regex("NUM_CHANNELS_TWO" ).Replace(shaderSource, $"{channelsTwo}");
545-
shaderSource = new Regex("NUM_CHANNELS_THREE").Replace(shaderSource, $"{channelsThree}");
546521
shaderSource = new Regex("BIAS_LIST_ZERO" ).Replace(shaderSource, $"{biasListZero}");
547522
shaderSource = new Regex("BIAS_LIST_ONE" ).Replace(shaderSource, $"{biasListOne}");
548523
shaderSource = new Regex("BIAS_LIST_TWO" ).Replace(shaderSource, $"{biasListTwo}");
549524

525+
for (int i = 0; i < mlp._0Weights.Length; i++) {
526+
shaderSource = new Regex($"__W0_{i}__").Replace(shaderSource, $"{toConstructorList(mlp._0Weights[i])}");
527+
}
528+
for (int i = 0; i < mlp._1Weights.Length; i++) {
529+
shaderSource = new Regex($"__W1_{i}__").Replace(shaderSource, $"{toConstructorList(mlp._1Weights[i])}");
530+
}
531+
for (int i = 0; i < mlp._2Weights.Length; i++) {
532+
shaderSource = new Regex($"__W2_{i}__").Replace(shaderSource, $"{toConstructorList(mlp._2Weights[i])}");
533+
}
534+
550535
// hack way to flip axes depending on scene
551536
string axisSwizzle = MNeRFSceneExtensions.ToEnum(objName).GetAxisSwizzleString();
552537
shaderSource = new Regex("AXIS_SWIZZLE" ).Replace(shaderSource, $"{axisSwizzle}");
@@ -556,52 +541,12 @@ private static void CreateShader(string objName, Mlp mlp) {
556541
AssetDatabase.Refresh();
557542
}
558543

559-
private static void CreateWeightTextures(string objName, Mlp mlp) {
560-
Texture2D weightsTexZero = createFloatTextureFromData(mlp._0Weights);
561-
Texture2D weightsTexOne = createFloatTextureFromData(mlp._1Weights);
562-
Texture2D weightsTexTwo = createFloatTextureFromData(mlp._2Weights);
563-
AssetDatabase.CreateAsset(weightsTexZero, GetWeightsAssetPath(objName, 0));
564-
AssetDatabase.CreateAsset(weightsTexOne, GetWeightsAssetPath(objName, 1));
565-
AssetDatabase.CreateAsset(weightsTexTwo, GetWeightsAssetPath(objName, 2));
566-
AssetDatabase.SaveAssets();
567-
}
568-
569-
/// <summary>
570-
/// Creates a float32 texture from an array of floats.
571-
/// </summary>
572-
private static Texture2D createFloatTextureFromData(double[][] weights) {
573-
int width = weights.Length;
574-
int height = weights[0].Length;
575-
576-
Texture2D texture = new Texture2D(width, height, TextureFormat.RFloat, mipChain: false, linear: true);
577-
texture.filterMode = FilterMode.Point;
578-
texture.wrapMode = TextureWrapMode.Clamp;
579-
NativeArray<float> textureData = texture.GetRawTextureData<float>();
580-
FillTexture(textureData, weights);
581-
texture.Apply();
582-
583-
return texture;
584-
}
585-
586-
private static void FillTexture(NativeArray<float> textureData, double[][] data) {
587-
int width = data.Length;
588-
int height = data[0].Length;
589-
590-
for (int co = 0; co < height; co++) {
591-
for (int ci = 0; ci < width; ci++) {
592-
int index = co * width + ci;
593-
double weight = data[ci][co];
594-
textureData[index] = (float)weight;
595-
}
596-
}
597-
}
598-
599-
private static StringBuilder toBiasList(double[] biases) {
544+
private static StringBuilder toConstructorList(double[] list) {
600545
System.Globalization.CultureInfo culture = System.Globalization.CultureInfo.InvariantCulture;
601-
int width = biases.Length;
546+
int width = list.Length;
602547
StringBuilder biasList = new StringBuilder(width * 12);
603548
for (int i = 0; i < width; i++) {
604-
double bias = biases[i];
549+
double bias = list[i];
605550
biasList.Append(bias.ToString("F7", culture));
606551
if (i + 1 < width) {
607552
biasList.Append(", ");

Editor/MobileNeRFScene.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static string GetAxisSwizzleString(this MNeRFScene scene) {
4545
case MNeRFScene.Custom:
4646
// Based on user feedback for custom scenes
4747
if (MobileNeRFImporter.SwizzleAxis) {
48-
return "o.rayDirection.xz = -o.rayDirection.xz;" +
48+
return "o.rayDirection.xz = -o.rayDirection.xz;" + Environment.NewLine +
4949
"o.rayDirection.xyz = o.rayDirection.xzy;";
5050
} else {
5151
return "o.rayDirection.x = -o.rayDirection.x;";

Editor/ShaderTemplate.cs

Lines changed: 63 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@ public static class ViewDependenceNetworkShader {
33
Properties {
44
tDiffuse0x (""Diffuse Texture 0"", 2D) = ""white"" {}
55
tDiffuse1x (""Diffuse Texture 1"", 2D) = ""white"" {}
6-
weightsZero (""Weights Zero"", 2D) = ""white"" {}
7-
weightsOne (""Weights One"", 2D) = ""white"" {}
8-
weightsTwo (""Weights Two"", 2D) = ""white"" {}
96
}
107
118
CGINCLUDE
@@ -41,66 +38,68 @@ v2f vert(appdata v) {
4138
4239
sampler2D tDiffuse0x;
4340
sampler2D tDiffuse1x;
44-
sampler2D tDiffuse2x;
45-
46-
UNITY_DECLARE_TEX2D(weightsZero);
47-
UNITY_DECLARE_TEX2D(weightsOne);
48-
UNITY_DECLARE_TEX2D(weightsTwo);
4941
5042
half3 evaluateNetwork(fixed4 f0, fixed4 f1, fixed4 viewdir) {
51-
half intermediate_one[NUM_CHANNELS_ONE] = { BIAS_LIST_ZERO };
52-
int i = 0;
53-
int j = 0;
54-
55-
for (j = 0; j < NUM_CHANNELS_ZERO; ++j) {
56-
half input_value = 0.0;
57-
if (j < 4) {
58-
input_value =
59-
(j == 0) ? f0.r : (
60-
(j == 1) ? f0.g : (
61-
(j == 2) ? f0.b : f0.a));
62-
} else if (j < 8) {
63-
input_value =
64-
(j == 4) ? f1.r : (
65-
(j == 5) ? f1.g : (
66-
(j == 6) ? f1.b : f1.a));
67-
} else {
68-
input_value =
69-
(j == 8) ? viewdir.r : (
70-
(j == 9) ? viewdir.g : viewdir.b);
71-
}
72-
for (i = 0; i < NUM_CHANNELS_ONE; ++i) {
73-
intermediate_one[i] += input_value * weightsZero.Load(int3(j, i, 0)).x;
74-
}
75-
}
76-
77-
half intermediate_two[NUM_CHANNELS_TWO] = { BIAS_LIST_ONE };
78-
79-
for (j = 0; j < NUM_CHANNELS_ONE; ++j) {
80-
if (intermediate_one[j] <= 0.0) {
81-
continue;
82-
}
83-
for (i = 0; i < NUM_CHANNELS_TWO; ++i) {
84-
intermediate_two[i] += intermediate_one[j] * weightsOne.Load(int3(j, i, 0)).x;
85-
}
86-
}
87-
88-
half result[NUM_CHANNELS_THREE] = { BIAS_LIST_TWO };
89-
90-
for (j = 0; j < NUM_CHANNELS_TWO; ++j) {
91-
if (intermediate_two[j] <= 0.0) {
92-
continue;
93-
}
94-
for (i = 0; i < NUM_CHANNELS_THREE; ++i) {
95-
result[i] += intermediate_two[j] * weightsTwo.Load(int3(j, i, 0)).x;
96-
}
97-
}
98-
for (i = 0; i < NUM_CHANNELS_THREE; ++i) {
99-
result[i] = 1.0 / (1.0 + exp(-result[i]));
100-
}
101-
return half3(result[0]*viewdir.a+(1.0-viewdir.a),
102-
result[1]*viewdir.a+(1.0-viewdir.a),
103-
result[2]*viewdir.a+(1.0-viewdir.a));
43+
float4x4 intermediate_one = { BIAS_LIST_ZERO };
44+
intermediate_one += f0.r * float4x4(__W0_0__)
45+
+ f0.g * float4x4(__W0_1__)
46+
+ f0.b * float4x4(__W0_2__)
47+
+ f0.a * float4x4(__W0_3__)
48+
+ f1.r * float4x4(__W0_4__)
49+
+ f1.g * float4x4(__W0_5__)
50+
+ f1.b * float4x4(__W0_6__)
51+
+ f1.a * float4x4(__W0_7__)
52+
+ viewdir.r * float4x4(__W0_8__)
53+
+ viewdir.g * float4x4(__W0_9__)
54+
+ viewdir.b * float4x4(__W0_10__);
55+
intermediate_one[0] = max(intermediate_one[0], 0.0);
56+
intermediate_one[1] = max(intermediate_one[1], 0.0);
57+
intermediate_one[2] = max(intermediate_one[2], 0.0);
58+
intermediate_one[3] = max(intermediate_one[3], 0.0);
59+
float4x4 intermediate_two = float4x4(
60+
BIAS_LIST_ONE
61+
);
62+
intermediate_two += intermediate_one[0][0] * float4x4(__W1_0__)
63+
+ intermediate_one[0][1] * float4x4(__W1_1__)
64+
+ intermediate_one[0][2] * float4x4(__W1_2__)
65+
+ intermediate_one[0][3] * float4x4(__W1_3__)
66+
+ intermediate_one[1][0] * float4x4(__W1_4__)
67+
+ intermediate_one[1][1] * float4x4(__W1_5__)
68+
+ intermediate_one[1][2] * float4x4(__W1_6__)
69+
+ intermediate_one[1][3] * float4x4(__W1_7__)
70+
+ intermediate_one[2][0] * float4x4(__W1_8__)
71+
+ intermediate_one[2][1] * float4x4(__W1_9__)
72+
+ intermediate_one[2][2] * float4x4(__W1_10__)
73+
+ intermediate_one[2][3] * float4x4(__W1_11__)
74+
+ intermediate_one[3][0] * float4x4(__W1_12__)
75+
+ intermediate_one[3][1] * float4x4(__W1_13__)
76+
+ intermediate_one[3][2] * float4x4(__W1_14__)
77+
+ intermediate_one[3][3] * float4x4(__W1_15__);
78+
intermediate_two[0] = max(intermediate_two[0], 0.0);
79+
intermediate_two[1] = max(intermediate_two[1], 0.0);
80+
intermediate_two[2] = max(intermediate_two[2], 0.0);
81+
intermediate_two[3] = max(intermediate_two[3], 0.0);
82+
float3 result = float3(
83+
BIAS_LIST_TWO
84+
);
85+
result += intermediate_two[0][0] * float3(__W2_0__)
86+
+ intermediate_two[0][1] * float3(__W2_1__)
87+
+ intermediate_two[0][2] * float3(__W2_2__)
88+
+ intermediate_two[0][3] * float3(__W2_3__)
89+
+ intermediate_two[1][0] * float3(__W2_4__)
90+
+ intermediate_two[1][1] * float3(__W2_5__)
91+
+ intermediate_two[1][2] * float3(__W2_6__)
92+
+ intermediate_two[1][3] * float3(__W2_7__)
93+
+ intermediate_two[2][0] * float3(__W2_8__)
94+
+ intermediate_two[2][1] * float3(__W2_9__)
95+
+ intermediate_two[2][2] * float3(__W2_10__)
96+
+ intermediate_two[2][3] * float3(__W2_11__)
97+
+ intermediate_two[3][0] * float3(__W2_12__)
98+
+ intermediate_two[3][1] * float3(__W2_13__)
99+
+ intermediate_two[3][2] * float3(__W2_14__)
100+
+ intermediate_two[3][3] * float3(__W2_15__);
101+
result = 1.0 / (1.0 + exp(-result));
102+
return result*viewdir.a+(1.0-viewdir.a);
104103
}
105104
ENDCG
106105
@@ -120,10 +119,9 @@ fixed4 frag(v2f i) : SV_Target {
120119
fixed4 diffuse1 = tex2D( tDiffuse1x, i.uv );
121120
fixed4 rayDir = fixed4(normalize(i.rayDirection), 1.0);
122121
123-
//deal with iphone
124-
diffuse0.a = diffuse0.a*2.0-1.0;
125-
diffuse1.a = diffuse1.a*2.0-1.0;
126-
rayDir.a = rayDir.a*2.0-1.0;
122+
// normalize range to [-1, 1]
123+
diffuse0.a = diffuse0.a * 2.0 - 1.0;
124+
diffuse1.a = diffuse1.a * 2.0 - 1.0;
127125
128126
fixed4 fragColor;
129127
fragColor.rgb = evaluateNetwork(diffuse0,diffuse1,rayDir);

Editor/WebRequestAsyncUtility.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,11 @@ private static UnityWebRequest GetRequest(string url, HTTPVerb verb, string post
142142
webRequest = UnityWebRequest.Get(url);
143143
break;
144144
case HTTPVerb.POST:
145+
#if UNITY_2022_2_OR_NEWER
146+
webRequest = UnityWebRequest.PostWwwForm(url, postData);
147+
#else
145148
webRequest = UnityWebRequest.Post(url, postData);
149+
#endif
146150
byte[] rawBody = Encoding.UTF8.GetBytes(postData);
147151
webRequest.uploadHandler = new UploadHandlerRaw(rawBody);
148152
webRequest.downloadHandler = new DownloadHandlerBuffer();

0 commit comments

Comments
 (0)