Skip to content

Commit 03ba476

Browse files
Simplify RHS operation
1 parent d5cadb0 commit 03ba476

File tree

1 file changed

+70
-79
lines changed

1 file changed

+70
-79
lines changed

src/+otp/RHS.m

Lines changed: 70 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
classdef RHS
1+
classdef RHS
22
properties (SetAccess = private)
33
F
44

@@ -65,43 +65,43 @@
6565
end
6666

6767
function newRHS = plus(obj1, obj2)
68-
newRHS = applyOp(obj1, obj2, @plus, 2);
68+
newRHS = applyOp(obj1, obj2, @plus, @(r, ~) r, @(~, r) r, @plus);
6969
end
7070

7171
function newRHS = minus(obj1, obj2)
72-
newRHS = applyOp(obj1, obj2, @minus, 2);
72+
newRHS = applyOp(obj1, obj2, @minus, @(r, ~) r, @(~, r) r, @minus);
7373
end
7474

7575
function newRHS = mtimes(obj1, obj2)
76-
newRHS = applyOp(obj1, obj2, @mtimes, 1);
76+
newRHS = applyOp(obj1, obj2, @mtimes, @mtimes, @mtimes, []);
7777
end
7878

7979
function newRHS = times(obj1, obj2)
80-
newRHS = applyOp(obj1, obj2, @times, 1);
80+
newRHS = applyOp(obj1, obj2, @times, @times, @times, []);
8181
end
8282

8383
function newRHS = rdivide(obj1, obj2)
84-
newRHS = applyOp(obj1, obj2, @rdivide, 1);
84+
newRHS = applyOp(obj1, obj2, @rdivide, @rdivide, @rdivide, []);
8585
end
8686

8787
function newRHS = ldivide(obj1, obj2)
88-
newRHS = applyOp(obj1, obj2, @ldivide, 1);
88+
newRHS = applyOp(obj1, obj2, @ldivide, @ldivide, @ldivide, []);
8989
end
9090

9191
function newRHS = mrdivide(obj1, obj2)
92-
newRHS = applyOp(obj1, obj2, @mrdivide, 1);
92+
newRHS = applyOp(obj1, obj2, @mrdivide, @mrdivide, @mrdivide, []);
9393
end
9494

9595
function newRHS = mldivide(obj1, obj2)
96-
newRHS = applyOp(obj1, obj2, @mldivide, 1);
96+
newRHS = applyOp(obj1, obj2, @mldivide, @mldivide, @mldivide, []);
9797
end
9898

9999
function newRHS = power(obj1, obj2)
100-
newRHS = applyOp(obj1, obj2, @power, 0);
100+
newRHS = applyOp(obj1, obj2, @power, [], [], []);
101101
end
102102

103103
function newRHS = mpower(obj1, obj2)
104-
newRHS = applyOp(obj1, obj2, @mpower, 0);
104+
newRHS = applyOp(obj1, obj2, @mpower, [], [], []);
105105
end
106106

107107
function opts = odeset(obj, varargin)
@@ -122,93 +122,84 @@
122122
end
123123

124124
methods (Access = private)
125-
function mat = prop2Matrix(~, prop)
126-
if isa(prop, 'function_handle')
125+
function mat = prop2Matrix(~, p)
126+
if isa(p, 'function_handle')
127127
mat = [];
128128
else
129-
mat = prop;
129+
mat = p;
130130
end
131131
end
132132

133-
function fun = prop2Function(~, prop)
134-
if isa(prop, 'function_handle') || isempty(prop)
135-
fun = prop;
133+
function fun = prop2Function(~, p)
134+
if isa(p, 'function_handle') || isempty(p)
135+
fun = p;
136136
else
137-
fun = @(varargin) prop;
137+
fun = @(varargin) p;
138138
end
139139
end
140140

141-
function newRHS = applyOp(obj1, obj2, op, differentiability)
141+
function newRHS = applyOp(obj1, obj2, op, dOpLeft, dOpRight, dOpBoth)
142+
if isa(obj1, 'function_handle')
143+
obj1 = otp.RHS(obj1);
144+
elseif isa(obj2, 'function_handle')
145+
obj2 = otp.RHS(obj2);
146+
end
147+
148+
if isa(obj1, 'otp.RHS')
149+
primaryRHS = obj1;
150+
if isa(obj2, 'otp.RHS')
151+
f = otp.RHS.mergeProp(obj1.F, obj2.F, op);
152+
merge = @(p) otp.RHS.mergeProp(obj1.(p), obj2.(p), dOpBoth);
153+
154+
if strcmp(obj1.Vectorized, obj2.Vectorized)
155+
vectorized = obj1.Vectorized;
156+
end
157+
else
158+
f = otp.RHS.mergeProp(obj1.F, obj2, op);
159+
merge = @(p) otp.RHS.mergeProp(obj1.(p), obj2, dOpLeft);
160+
vectorized = obj1.Vectorized;
161+
end
162+
else
163+
primaryRHS = obj2;
164+
f = otp.RHS.mergeProp(obj1, obj2.F, op);
165+
merge = @(p) otp.RHS.mergeProp(obj1, obj2.(p), dOpRight);
166+
vectorized = obj2.Vectorized;
167+
end
168+
142169
% Events and NonNegative practically cannot be supported and are
143170
% always unset.
144171

172+
% JPattern is problematic to compute for division operators due to
173+
% singular patterns
174+
145175
% Mass matrices introduce several difficulties. When singular, it
146176
% makes it infeasible to update InitialSlope, and therefore, it is
147177
% always unset. To avoid issues with two RHS' having different mass
148178
% matrices, only the primary RHS is used.
149-
[~, ~, props.Mass] = getProp(obj1, obj2, 'Mass');
150-
[~, ~, props.MassSingular] = getProp(obj1, obj2, 'MassSingular');
151-
[~, ~, props.MStateDependence] = getProp(obj1, obj2, ...
152-
'MStateDependence');
153-
[~, ~, props.MvPattern] = getProp(obj1, obj2, 'MvPattern');
154-
155-
% Merge derivatives
156-
props.Jacobian = mergeProp(obj1, obj2, op, differentiability, ...
157-
'Jacobian');
158-
props.JacobianVectorProduct = mergeProp(obj1, obj2, op, ...
159-
differentiability, 'JacobianVectorProduct');
160-
props.JacobianAdjointVectorProduct = mergeProp(obj1, obj2, op, ...
161-
differentiability, 'JacobianAdjointVectorProduct');
162-
props.PartialDerivativeParameters = mergeProp(obj1, obj2, op, ...
163-
differentiability, 'PartialDerivativeParameters');
164-
props.PartialDerivativeTime = mergeProp(obj1, obj2, op, ...
165-
differentiability, 'PartialDerivativeTime');
166-
props.HessianVectorProduct = mergeProp(obj1, obj2, op, ...
167-
differentiability, 'HessianVectorProduct');
168-
props.HessianAdjointVectorProduct = mergeProp(obj1, obj2, op, ...
169-
differentiability, 'HessianAdjointVectorProduct');
170-
171-
% JPattern requirs a special merge function
172-
if differentiability == 2
173-
patternOp = @or;
174-
else
175-
patternOp = @(j1, j2) op(j1 ~= 0, j2 ~=0) ~= 0;
176-
end
177-
props.JPattern = mergeProp(obj1, obj2, patternOp, ...
178-
differentiability, 'JPattern');
179179

180-
% Vectorization
181-
[v1, v2, vPrimary, numRHS] = getProp(obj1, obj2, 'Vectorized');
182-
if numRHS == 1 || strcmp({v1, v2}, 'on')
183-
props.Vectorized = vPrimary;
184-
end
185-
186-
newRHS = otp.RHS(mergeProp(obj1, obj2, op, inf, 'F'), props);
187-
end
188-
189-
function [obj1, obj2, primary, numRHS] = getProp(obj1, obj2, prop)
190-
numRHS = 1;
191-
192-
if isa(obj1, 'otp.RHS')
193-
obj1 = obj1.(prop);
194-
if isa(obj2, 'otp.RHS')
195-
obj2 = obj2.(prop);
196-
numRHS = 2;
197-
end
198-
primary = obj1;
199-
else
200-
obj2 = obj2.(prop);
201-
primary = obj2;
202-
end
180+
newRHS = otp.RHS(f, ...
181+
'Mass', primaryRHS.Mass, ...
182+
'MassSingular', primaryRHS.MassSingular, ...
183+
'MStateDependence', primaryRHS.MStateDependence, ...
184+
'MvPattern', primaryRHS.MvPattern, ...
185+
'Jacobian', merge('Jacobian'), ...
186+
'JacobianVectorProduct', merge('JacobianVectorProduct'), ...
187+
'JacobianAdjointVectorProduct', ...
188+
merge('JacobianAdjointVectorProduct'), ...
189+
'PartialDerivativeParameters', ...
190+
merge('PartialDerivativeParameters'), ...
191+
'PartialDerivativeTime', merge('PartialDerivativeTime'), ...
192+
'HessianVectorProduct', merge('HessianVectorProduct'), ...
193+
'HessianAdjointVectorProduct', ...
194+
merge('HessianAdjointVectorProduct'), ...
195+
'Vectorized', vectorized);
203196
end
204-
205-
function p = mergeProp(obj1, obj2, op, differentiability, prop)
206-
[p1, p2, pPrimary, numRHS] = getProp(obj1, obj2, prop);
207-
208-
if isempty(p1) || isempty(p2) || numRHS > differentiability
197+
end
198+
199+
methods (Static, Access = private)
200+
function p = mergeProp(p1, p2, op)
201+
if isempty(p1) || isempty(p2) || isempty(op)
209202
p = [];
210-
elseif numRHS == differentiability - 1
211-
p = pPrimary;
212203
elseif isa(p1, 'function_handle')
213204
if isa(p2, 'function_handle')
214205
p = @(varargin) op(p1(varargin{:}), p2(varargin{:}));

0 commit comments

Comments
 (0)