Skip to content

Commit d5cadb0

Browse files
Finish implementing operator overloading
1 parent b5e9c62 commit d5cadb0

File tree

1 file changed

+156
-58
lines changed

1 file changed

+156
-58
lines changed

src/+otp/RHS.m

Lines changed: 156 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@
2121
HessianAdjointVectorProduct
2222
OnEvent
2323
end
24+
25+
properties (Dependent)
26+
JacobianMatrix
27+
JacobianFunction
28+
MassMatrix
29+
MassFunction
30+
end
2431

2532
methods
2633
function obj = RHS(F, varargin)
@@ -33,73 +40,68 @@
3340
obj.(f) = extras.(f);
3441
end
3542
end
36-
37-
function newRHS = subsref(obj, vs)
38-
if strcmp(vs(1).type, '()')
39-
newF = @(t, y) subsref(obj.F(t, y), vs);
40-
41-
newJac = [];
42-
if ~isempty(obj.Jacobian)
43-
newJac = @(t, y) subsref(obj.Jacobian(t, y), vs);
44-
end
45-
newJacvp = [];
46-
if ~isempty(obj.JacobianVectorProduct)
47-
newJacvp = @(t, y, v) subsref(obj.JacobianVectorProduct(t, y, v), vs);
48-
end
49-
50-
vectorized = obj.Vectorized;
51-
52-
newRHS = otp.RHS(newF, ...
53-
'Jacobian', newJac, ...
54-
'JacobianVectorProduct', newJacvp, ...
55-
'Vectorized', vectorized);
56-
else
57-
newRHS = builtin('subsref', obj, vs);
58-
end
43+
44+
function mat = get.JacobianMatrix(obj)
45+
mat = obj.prop2Matrix(obj.Jacobian);
5946
end
60-
61-
function newRHS = plus(obj, other)
62-
objF = obj.F;
63-
otherF = other.F;
64-
newF = @(t, y) objF(t, y) + otherF(t, y);
65-
newRHS = otp.RHS(newF);
47+
48+
function fun = get.JacobianFunction(obj)
49+
fun = obj.prop2Function(obj.Jacobian);
50+
end
51+
52+
function mat = get.MassMatrix(obj)
53+
mat = obj.prop2Matrix(obj.Mass);
54+
end
55+
56+
function fun = get.MassFunction(obj)
57+
fun = obj.prop2Function(obj.Mass);
58+
end
59+
60+
function obj = uplus(obj)
61+
end
62+
63+
function newRHS = uminus(obj)
64+
newRHS = mtimes(-1, obj);
6665
end
6766

68-
function newRHS = vertcat(varargin)
69-
newF = @(t, y) [];
70-
newJac = @(t, y) [];
71-
newJacvp = @(t, y, v) [];
67+
function newRHS = plus(obj1, obj2)
68+
newRHS = applyOp(obj1, obj2, @plus, 2);
69+
end
7270

73-
for i = 1:numel(varargin)
74-
oldRHS = varargin{i};
75-
oldF = oldRHS.F;
71+
function newRHS = minus(obj1, obj2)
72+
newRHS = applyOp(obj1, obj2, @minus, 2);
73+
end
7674

77-
newF = @(t, y) [newF(t, y); oldF(t, y)];
75+
function newRHS = mtimes(obj1, obj2)
76+
newRHS = applyOp(obj1, obj2, @mtimes, 1);
77+
end
7878

79-
oldJac = oldRHS.Jacbobian;
80-
if ~isempty(oldJac)
81-
newJac = @(t, y) [newJac(t, y); oldJac(t, y)];
82-
end
79+
function newRHS = times(obj1, obj2)
80+
newRHS = applyOp(obj1, obj2, @times, 1);
81+
end
8382

84-
oldJacvp = oldRHS.JacobianVectorProduct;
85-
if ~isempty(oldJacvp)
86-
newJacvp = @(t, y, v) [newJacvp(t, y, v); oldJacvp(t, y, v)];
87-
end
83+
function newRHS = rdivide(obj1, obj2)
84+
newRHS = applyOp(obj1, obj2, @rdivide, 1);
85+
end
8886

89-
end
87+
function newRHS = ldivide(obj1, obj2)
88+
newRHS = applyOp(obj1, obj2, @ldivide, 1);
89+
end
9090

91-
91+
function newRHS = mrdivide(obj1, obj2)
92+
newRHS = applyOp(obj1, obj2, @mrdivide, 1);
93+
end
9294

93-
vectorized = obj.Vectorized;
95+
function newRHS = mldivide(obj1, obj2)
96+
newRHS = applyOp(obj1, obj2, @mldivide, 1);
97+
end
9498

95-
newRHS = otp.RHS(newF, ...
96-
'Jacobian', newJac, ...
97-
'JacobianVectorProduct', newJacvp, ...
98-
'Vectorized', vectorized);
99+
function newRHS = power(obj1, obj2)
100+
newRHS = applyOp(obj1, obj2, @power, 0);
99101
end
100102

101-
function s = size(~)
102-
s = [1, 1];
103+
function newRHS = mpower(obj1, obj2)
104+
newRHS = applyOp(obj1, obj2, @mpower, 0);
103105
end
104106

105107
function opts = odeset(obj, varargin)
@@ -118,10 +120,106 @@
118120
end
119121

120122
end
121-
122-
methods (Static)
123-
function newRHS = empty(obj, other)
124-
error('');
123+
124+
methods (Access = private)
125+
function mat = prop2Matrix(~, prop)
126+
if isa(prop, 'function_handle')
127+
mat = [];
128+
else
129+
mat = prop;
130+
end
131+
end
132+
133+
function fun = prop2Function(~, prop)
134+
if isa(prop, 'function_handle') || isempty(prop)
135+
fun = prop;
136+
else
137+
fun = @(varargin) prop;
138+
end
139+
end
140+
141+
function newRHS = applyOp(obj1, obj2, op, differentiability)
142+
% Events and NonNegative practically cannot be supported and are
143+
% always unset.
144+
145+
% Mass matrices introduce several difficulties. When singular, it
146+
% makes it infeasible to update InitialSlope, and therefore, it is
147+
% always unset. To avoid issues with two RHS' having different mass
148+
% matrices, only the primary RHS is used.
149+
[~, ~, props.Mass] = getProp(obj1, obj2, 'Mass');
150+
[~, ~, props.MassSingular] = getProp(obj1, obj2, 'MassSingular');
151+
[~, ~, props.MStateDependence] = getProp(obj1, obj2, ...
152+
'MStateDependence');
153+
[~, ~, props.MvPattern] = getProp(obj1, obj2, 'MvPattern');
154+
155+
% Merge derivatives
156+
props.Jacobian = mergeProp(obj1, obj2, op, differentiability, ...
157+
'Jacobian');
158+
props.JacobianVectorProduct = mergeProp(obj1, obj2, op, ...
159+
differentiability, 'JacobianVectorProduct');
160+
props.JacobianAdjointVectorProduct = mergeProp(obj1, obj2, op, ...
161+
differentiability, 'JacobianAdjointVectorProduct');
162+
props.PartialDerivativeParameters = mergeProp(obj1, obj2, op, ...
163+
differentiability, 'PartialDerivativeParameters');
164+
props.PartialDerivativeTime = mergeProp(obj1, obj2, op, ...
165+
differentiability, 'PartialDerivativeTime');
166+
props.HessianVectorProduct = mergeProp(obj1, obj2, op, ...
167+
differentiability, 'HessianVectorProduct');
168+
props.HessianAdjointVectorProduct = mergeProp(obj1, obj2, op, ...
169+
differentiability, 'HessianAdjointVectorProduct');
170+
171+
% JPattern requirs a special merge function
172+
if differentiability == 2
173+
patternOp = @or;
174+
else
175+
patternOp = @(j1, j2) op(j1 ~= 0, j2 ~=0) ~= 0;
176+
end
177+
props.JPattern = mergeProp(obj1, obj2, patternOp, ...
178+
differentiability, 'JPattern');
179+
180+
% Vectorization
181+
[v1, v2, vPrimary, numRHS] = getProp(obj1, obj2, 'Vectorized');
182+
if numRHS == 1 || strcmp({v1, v2}, 'on')
183+
props.Vectorized = vPrimary;
184+
end
185+
186+
newRHS = otp.RHS(mergeProp(obj1, obj2, op, inf, 'F'), props);
187+
end
188+
189+
function [obj1, obj2, primary, numRHS] = getProp(obj1, obj2, prop)
190+
numRHS = 1;
191+
192+
if isa(obj1, 'otp.RHS')
193+
obj1 = obj1.(prop);
194+
if isa(obj2, 'otp.RHS')
195+
obj2 = obj2.(prop);
196+
numRHS = 2;
197+
end
198+
primary = obj1;
199+
else
200+
obj2 = obj2.(prop);
201+
primary = obj2;
202+
end
203+
end
204+
205+
function p = mergeProp(obj1, obj2, op, differentiability, prop)
206+
[p1, p2, pPrimary, numRHS] = getProp(obj1, obj2, prop);
207+
208+
if isempty(p1) || isempty(p2) || numRHS > differentiability
209+
p = [];
210+
elseif numRHS == differentiability - 1
211+
p = pPrimary;
212+
elseif isa(p1, 'function_handle')
213+
if isa(p2, 'function_handle')
214+
p = @(varargin) op(p1(varargin{:}), p2(varargin{:}));
215+
else
216+
p = @(varargin) op(p1(varargin{:}), p2);
217+
end
218+
elseif isa(p2, 'function_handle')
219+
p = @(varargin) op(p1, p2(varargin{:}));
220+
else
221+
p = op(p1, p2);
222+
end
125223
end
126224
end
127225
end

0 commit comments

Comments
 (0)