Skip to content

Commit afe7e12

Browse files
authored
Merge pull request #226 from JingyuanZhang/master
feat(core): update model vars structure for high performance
2 parents 328b94b + f9c2d8c commit afe7e12

File tree

11 files changed

+102
-35
lines changed

11 files changed

+102
-35
lines changed

packages/paddlejs-converter/convertModel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# paddlepaddle运行程序实例
3434
program = None
3535
# 存放模型结构
36-
modelInfo = {"vars": [], "ops": [], "chunkNum": 0}
36+
modelInfo = {"vars": {}, "ops": [], "chunkNum": 0, "dataLayout": "nhwc"}
3737
# 存放参数数值(未排序)
3838
paramValuesDict = {}
3939

@@ -236,7 +236,7 @@ def organizeModelVariableInfo(result):
236236
# 将var信息按照顺序,添加到model info的vars中
237237
for key, value in varInfoOrderDict.items():
238238
value["name"] = key
239-
modelInfo["vars"].append(value)
239+
modelInfo["vars"][key] = value
240240
print("Organizing model variables info successfully.")
241241

242242
def organizeModelOpInfo():

packages/paddlejs-core/src/commons/utils.ts

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ export function getGlobalInterface(): any {
2828
/**
2929
* getOrMakeGlobalProperty
3030
* @param {string} key property key
31-
* @param { any } value property value
32-
* @returns { any } global property
31+
* @param {any} value property value
32+
* @returns {any} global property
3333
*/
3434
export function getOrMakeGlobalProperty(key: string, value?: Object | String | Number | Boolean)
3535
: any {
@@ -39,4 +39,63 @@ export function getOrMakeGlobalProperty(key: string, value?: Object | String | N
3939
}
4040
globalInterface[key] = value;
4141
return globalInterface[key];
42+
}
43+
44+
45+
/**
46+
* find target var by key
47+
* @param {Object | Array} vars model vars
48+
* @param {string} key var name
49+
* @returns {Object} var
50+
*/
51+
export function findVarByKey(vars, key) {
52+
if (Array.isArray(vars)) {
53+
return vars.find(item => item.name === key);
54+
}
55+
return vars[key];
56+
}
57+
58+
59+
/**
60+
* add var to vars
61+
* @param {Object | Array} vars model vars
62+
* @param {Object} item var
63+
*/
64+
export function AddItemToVars(vars, item) {
65+
const isVarsArray = Array.isArray(vars);
66+
const isItemArray = Array.isArray(item);
67+
if (isItemArray) {
68+
if (isVarsArray) {
69+
vars.splice(vars.length - 1, 0, ...item);
70+
}
71+
else {
72+
item.forEach(varItem => {
73+
vars[varItem.name] = varItem;
74+
});
75+
}
76+
return;
77+
}
78+
if (isVarsArray) {
79+
vars.push(item);
80+
return;
81+
}
82+
vars[item.name] = item;
83+
}
84+
85+
/**
86+
* traverse vars and deal var item
87+
* @param {Object | Array} vars model vars
88+
* @param {Function} callback deal var item
89+
*/
90+
export function traverseVars(vars, callback) {
91+
const isVarsArray = Array.isArray(vars);
92+
if (isVarsArray) {
93+
vars.forEach(item => {
94+
callback(item);
95+
});
96+
return;
97+
}
98+
Object.keys(vars).forEach(key => {
99+
callback(vars[key]);
100+
});
42101
}

packages/paddlejs-core/src/index.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { registerBackend, registerOp } from './globals';
44
import Env from './env';
55
import * as interfaces from './commons/interface';
66
import Transformer from './transform/transformer';
7+
import * as coreUtils from './commons/utils';
78

89
export {
910
Runner,
@@ -12,5 +13,6 @@ export {
1213
PaddlejsBackend,
1314
interfaces,
1415
Transformer,
15-
Env as env
16+
Env as env,
17+
coreUtils
1618
};

packages/paddlejs-core/src/loader.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
*/
44

55
import env from './env';
6-
import { Model, ModelVar } from './commons/interface';
6+
import { Model } from './commons/interface';
7+
import { traverseVars } from './commons/utils';
78

89
interface UrlConf {
910
dir: string;
@@ -124,12 +125,10 @@ export default class ModelLoader {
124125
});
125126
}
126127

127-
traverse(arr: ModelVar[], allChunksData: Float32Array) {
128+
traverse(vars, allChunksData: Float32Array) {
128129
let marker = 0; // 读到哪个位置了
129130
let len; // 当前op长度
130-
arr.filter(item => {
131-
return item.name;
132-
}).forEach(item => {
131+
traverseVars(vars, item => {
133132
len = item.shape.reduce((a, b) => a * b); // 长度为shape的乘积
134133
// 为了减少模型体积,模型转换工具不会导出非persistable的数据,这里只需要读取persistable的数据
135134
if (item.persistable) {

packages/paddlejs-core/src/opFactory/opDataBuilder.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { GLOBALS } from '../globals';
33
import Tensor from './tensor';
44
import opBehaviors from './opBehaviors';
55
import * as Utils from './utils';
6+
import { findVarByKey } from '../commons/utils';
67

78
// model的名字和paddleJS的tensor名字mapping
89

@@ -67,12 +68,12 @@ export default class OpData {
6768
constructTensorData() {
6869
Object.keys(this.output).forEach(key => {
6970
this.output[key].forEach((name: string, index: number) => {
70-
this.output[key][index] = this.getTensorVar(name)[0];
71+
this.output[key][index] = this.getTensorVar(name);
7172
});
7273
});
7374

7475
Object.keys(this.input).forEach(key => {
75-
this.input[key] = this.getTensorVar(this.input[key][0]);
76+
this.input[key] = [this.getTensorVar(this.input[key][0])];
7677
});
7778

7879
for (const key in this.output) {
@@ -142,10 +143,11 @@ export default class OpData {
142143
}
143144

144145
getTensorVar(name: string) {
145-
const data = this.vars.filter(item => item.name === name || item.name === name.replace(/_packed$/, ''));
146-
if (data.length > 0 && name.endsWith('_packed')) {
147-
const packedData = Utils.packOpData(data[0], name);
148-
return [packedData];
146+
const varName = name.replace(/_packed$/, '');
147+
const data = findVarByKey(this.vars, varName);
148+
if (data && name.endsWith('_packed')) {
149+
const packedData = Utils.packOpData(data, name);
150+
return packedData;
149151
}
150152
return data;
151153
}

packages/paddlejs-core/src/runner.ts

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { Model, ModelConfig, InputFeed, ModelVar, GraphType } from './commons/in
44
import OpData from './opFactory/opDataBuilder';
55
import Tensor from './opFactory/tensor';
66
import { GLOBALS } from './globals';
7-
import { getGlobalInterface } from './commons/utils';
7+
import { getGlobalInterface, findVarByKey, AddItemToVars } from './commons/utils';
88
import MediaProcessor from './mediaProcessor';
99
import env from './env';
1010

@@ -174,7 +174,7 @@ export default class Runner {
174174
if (feedOpInputs.length > 1) {
175175
// 多输入
176176
preheatFeedData = feedOpInputs.map(inputName => {
177-
const feedInfo = vars.find(item => item.name === inputName);
177+
const feedInfo = findVarByKey(vars, inputName);
178178
const shape = feedInfo.shape;
179179
const [w, h, c = 3, n = 1] = shape.reverse();
180180

@@ -184,7 +184,7 @@ export default class Runner {
184184
}
185185
}
186186
else {
187-
preheatFeedData = vars.find(item => item.name === 'image');
187+
preheatFeedData = findVarByKey(vars, 'image');
188188
if (preheatFeedData) {
189189
preheatFeedData.data = new Float32Array(fc * fh * fw).fill(1.0);
190190
return;
@@ -197,7 +197,7 @@ export default class Runner {
197197
};
198198
}
199199

200-
vars.push(preheatFeedData);
200+
AddItemToVars(vars, preheatFeedData);
201201
}
202202

203203
updateFeedData(feed) {
@@ -266,9 +266,7 @@ export default class Runner {
266266

267267
async read() {
268268
const fetchOp = this.graphGenerator.getFetchExecutor();
269-
const fetchVar = this.model.vars.find(
270-
item => item.name === fetchOp.inputs.X[0]
271-
) as ModelVar;
269+
const fetchVar = findVarByKey(this.model.vars, fetchOp.inputs.X[0]) as ModelVar;
272270
const fetchInfo = {
273271
name: fetchVar.name,
274272
shape: fetchOp.attrs['origin_shape'] || fetchVar.shape

packages/paddlejs-core/src/transform/feedProcess.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
*/
44

55
import { ModelOp } from '../commons/interface';
6+
import { findVarByKey, AddItemToVars } from '../commons/utils';
67
import env from '../env';
78
import Transformer from './transformer';
89

@@ -25,7 +26,7 @@ export default class WebglFeedProcess extends Transformer {
2526
} = modelConfig;
2627

2728
// make img_pre_processed var
28-
const imgVar = vars.find(item => item.name === 'image');
29+
const imgVar = findVarByKey(vars, 'image');
2930
const [, , h, w] = imgVar.shape;
3031
imgVar.shape = [1, 1, h, w];
3132
const processImgVar = Object.assign({}, imgVar);
@@ -40,8 +41,7 @@ export default class WebglFeedProcess extends Transformer {
4041
originImgVar.persistable = false;
4142
delete originImgVar.data;
4243

43-
vars.push(originImgVar);
44-
vars.push(processImgVar);
44+
AddItemToVars(vars, [originImgVar, processImgVar]);
4545

4646
// change recieve_img op input
4747
const imageOriginInputOp = ops.find(item => {

packages/paddlejs-core/src/transform/nhwc2nchw.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import env from '../env';
66
import { ModelOp } from '../commons/interface';
7+
import { findVarByKey, AddItemToVars } from '../commons/utils';
78
import { formatShape } from '../opFactory/utils';
89
import Transformer from './transformer';
910

@@ -21,7 +22,7 @@ export default class nhwc2nchw extends Transformer {
2122
const [ops, vars] = args;
2223
const fetchOp = ops.find(item => item.type === 'fetch');
2324
const [inputName] = fetchOp.inputs.X;
24-
const fetchInputVar = vars.find(item => item.name === inputName);
25+
const fetchInputVar = findVarByKey(vars, inputName);
2526
const [n, c, h, w] = formatShape(fetchInputVar.shape);
2627

2728
// transform data from nhwc to nchw
@@ -47,6 +48,6 @@ export default class nhwc2nchw extends Transformer {
4748

4849
fetchOp.inputs.X = [FINAL_NCHW_OP_NAME];
4950
ops.push(...[nchwOp]);
50-
vars.push(...[nchwVar]);
51+
AddItemToVars(vars, nchwVar);
5152
}
5253
}

packages/paddlejs-core/src/transform/packOutOp.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import env from '../env';
66
import { ModelOp } from '../commons/interface';
7+
import { findVarByKey, AddItemToVars } from '../commons/utils';
78
import { formatShape } from '../opFactory/utils';
89
import Transformer from './transformer';
910

@@ -22,7 +23,7 @@ export default class PackOut extends Transformer {
2223
const [ops, vars] = args;
2324
const fetchOp = ops.find(item => item.type === 'fetch');
2425
const [inputName] = fetchOp.inputs.X;
25-
const fetchInputVar = vars.find(item => item.name === inputName);
26+
const fetchInputVar = findVarByKey(vars, inputName);
2627
const [n, c, h, w] = formatShape(fetchInputVar.shape);
2728

2829
// transform data from nhwc to nchw
@@ -68,7 +69,7 @@ export default class PackOut extends Transformer {
6869
fetchOp.inputs.X = [FINAL_PACK_OP_NAME];
6970
fetchOp.attrs['origin_shape'] = [n, c, h, w];
7071
ops.push(...[nchwOp, packOutOp]);
71-
vars.push(...[nchwVar, packOutVar]);
72+
AddItemToVars(vars, [nchwVar, packOutVar]);
7273
}
7374
}
7475

packages/paddlejs-core/src/transform/splitOp.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
* @file graph transformer
33
*/
44

5+
import { findVarByKey, AddItemToVars } from '../commons/utils';
56
import Transformer from './transformer';
67

8+
79
function getTensorShapeFromVals(name, vars) {
8-
const result = vars.filter(item => item.name === name);
9-
return result.length ? result[0].shape : [];
10+
const result = findVarByKey(vars, name);
11+
return result ? result.shape : [];
1012
}
1113

1214
function buildOutputVarInfo(inputs, outputShape, axis, vars) {
@@ -83,7 +85,8 @@ export default class SplitOp extends Transformer {
8385
// change outputname of next op
8486
opList[opLen - 1].outputs.Out = [outputName];
8587
ops.splice(index, 1, ...opList);
86-
vars.splice(vars.length - 1, 0, ...varList);
88+
89+
AddItemToVars(vars, varList);
8790
}
8891
}
8992
}

0 commit comments

Comments
 (0)