
Commit ba669f5

sensitivity reg (#200)
* sensitivity reg
* local model, Plots explicit
* direct assign
* block separation
* fix block
* fix computations
* comments
* dont reg bias
* phrasing relu
* fix dot
1 parent 3d7e385 commit ba669f5

File tree: 2 files changed, +93 -57 lines


docs/src/examples/custom-relu.jl

Lines changed: 22 additions & 16 deletions
@@ -3,7 +3,7 @@
 #md # [![](https://img.shields.io/badge/show-github-579ACA.svg)](@__REPO_ROOT_URL__/docs/src/examples/custom-relu.jl)

 # We demonstrate how DiffOpt can be used to generate a simple neural network
-# unit - the ReLU layer. A neural network is created using Flux.jl which is
+# unit - the ReLU layer. A neural network is created using Flux.jl and
 # trained on the MNIST dataset.

 # This tutorial uses the following packages
@@ -15,6 +15,7 @@ import ChainRulesCore
 import Flux
 import Statistics
 import Base.Iterators: repeated
+using LinearAlgebra

 # ## The ReLU and its derivative

@@ -33,7 +34,7 @@ function matrix_relu(
         @objective(
             model,
             Min,
-            x'x -2y[:, i]'x # x' Q x + q'x with Q = I, q = -2y
+            dot(x, x) -2dot(y[:, i], x)
         )
         optimize!(model)
         _x[:, i] = value.(x)
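The objective above is the quadratic program behind the ReLU: with Q = I and q = -2y, minimizing x' Q x + q' x over nonnegative x recovers max.(y, 0) componentwise. A minimal standalone sketch of that equivalence follows (not part of the commit; it assumes the layer constrains x to be nonnegative, whose declaration is outside this hunk, and uses Ipopt purely for illustration):

    using JuMP
    import Ipopt

    y = [-1.5, 0.0, 2.3, -0.2, 4.0]         # a sample input column
    model = Model(Ipopt.Optimizer)
    set_silent(model)
    @variable(model, x[1:length(y)] >= 0)   # nonnegativity produces the ReLU kink
    # same objective as the layer: x'x - 2y'x, written term by term
    @objective(model, Min, sum(x[i]^2 - 2y[i] * x[i] for i in eachindex(y)))
    optimize!(model)
    isapprox(value.(x), max.(y, 0); atol = 1e-6)  # true up to solver tolerance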
@@ -52,23 +53,26 @@ function ChainRulesCore.rrule(
     function pullback_matrix_relu(dl_dx)
         ## some value from the backpropagation (e.g., loss) is denoted by `l`
         ## so `dl_dy` is the derivative of `l` wrt `y`
-        x = model[:x] # load decision variable `x` into scope
+        x = model[:x] ## load decision variable `x` into scope
         dl_dy = zeros(T, size(dl_dx))
-        dl_dq = zeros(T, size(dl_dx)) # for step-by-step explanation
+        dl_dq = zeros(T, size(dl_dx)) ## for step-by-step explanation
         for i in 1:size(y, 2)
+            ## set sensitivities
             MOI.set.(
                 model,
                 DiffOpt.BackwardInVariablePrimal(),
                 x,
                 dl_dx[:, i]
-            ) # set sensitivities
-            DiffOpt.backward(model) # compute grad
+            )
+            ## compute grad
+            DiffOpt.backward(model)
+            ## return gradient wrt objective function parameters
             obj_exp = MOI.get(
                 model,
                 DiffOpt.BackwardOutObjective()
-            ) # return gradient wrt objective function parameters
-            dl_dq[:, i] = JuMP.coefficient.(obj_exp, x) # coeff of `x` in q'x = -2y'x
-            dq_dy = -2 # dq/dy = -2
+            )
+            dl_dq[:, i] = JuMP.coefficient.(obj_exp, x) ## coeff of `x` in q'x = -2y'x
+            dq_dy = -2 ## dq/dy = -2
             dl_dy[:, i] = dl_dq[:, i] * dq_dy
         end
         return (ChainRulesCore.NoTangent(), dl_dy,)
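For readers tracing the pullback comments above: since the objective is x' Q x + q' x with q = -2y, the coefficients recovered by `JuMP.coefficient.(obj_exp, x)` are dl/dq, and the chain rule the loop applies is simply

    dl/dy[:, i] = dl/dq[:, i] * dq/dy = -2 * dl/dq[:, i],

which is exactly the line `dl_dy[:, i] = dl_dq[:, i] * dq_dy` with `dq_dy = -2`.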
@@ -84,8 +88,8 @@ imgs = MLDatasets.MNIST.traintensor(1:N)
 labels = MLDatasets.MNIST.trainlabels(1:N);

 # Preprocessing
-train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), N)) #stack all the images
-train_Y = Flux.onehotbatch(labels, 0:9); # just a common way to encode categorical variables
+train_X = float.(reshape(imgs, size(imgs, 1) * size(imgs, 2), N)) ## stack all the images
+train_Y = Flux.onehotbatch(labels, 0:9);

 test_imgs = MLDatasets.MNIST.testtensor(1:N)
 test_X = float.(reshape(test_imgs, size(test_imgs, 1) * size(test_imgs, 2), N))
@@ -114,9 +118,10 @@ dataset = repeated((train_X, train_Y), epochs);

 # Parameters for the network training

-custom_loss(x, y) = Flux.crossentropy(m(x), y) # training loss function
-opt = Flux.ADAM(); # stochastic gradient descent variant to optimize weights of the neural network
-evalcb = () -> @show(custom_loss(train_X, train_Y)); # callback to show loss
+# training loss function, Flux optimizer
+custom_loss(x, y) = Flux.crossentropy(m(x), y)
+opt = Flux.ADAM()
+evalcb = () -> @show(custom_loss(train_X, train_Y))

 # Train to optimize network parameters

@@ -125,9 +130,10 @@ evalcb = () -> @show(custom_loss(train_X, train_Y)); # callback to show loss
 # Although our custom implementation takes time, it is able to reach similar
 # accuracy as the usual ReLU function implementation.

-accuracy(x, y) = Statistics.mean(Flux.onecold(m(x)) .== Flux.onecold(y)); # average of correct guesses
+# Average of correct guesses
+accuracy(x, y) = Statistics.mean(Flux.onecold(m(x)) .== Flux.onecold(y));

-# Train accuracy
+# Training accuracy

 accuracy(train_X, train_Y)

docs/src/examples/sensitivity-analysis-ridge.jl

Lines changed: 71 additions & 41 deletions
@@ -22,7 +22,7 @@
 # ```math
 # \begin{split}
 # \begin{array} {ll}
-# \mbox{minimize} & e^{\top}e + \alpha (w^2 + b^2) \\
+# \mbox{minimize} & e^{\top}e + \alpha (w^2) \\
 # \mbox{s.t.} & e_{i} = y_{i} - w x_{i} - b \quad \quad i=1..N \\
 # \end{array}
 # \end{split}
@@ -36,15 +36,15 @@ import DiffOpt
 import Random
 import Ipopt
 import Plots
-import LinearAlgebra: normalize!, dot
+using LinearAlgebra: dot

 # ## Define and solve the problem

 # Construct a set of noisy (Gaussian) data points around a line.

 Random.seed!(42)

-N = 100
+N = 150

 w = 2 * abs(randn())
 b = rand()
@@ -64,76 +64,106 @@ function fit_ridge(X, Y, alpha = 0.1)
     set_silent(model)
     @variable(model, w) # angular coefficient
     @variable(model, b) # linear coefficient
-    @variable(model, e[1:N]) # approximation error
-    ## constraint defining approximation error
-    @constraint(model, cons[i=1:N], e[i] == Y[i] - w * X[i] - b)
+    ## expression defining approximation error
+    @expression(model, e[i=1:N], Y[i] - w * X[i] - b)
     ## objective minimizing squared error and ridge penalty
     @objective(
         model,
         Min,
-        dot(e, e) + alpha * (sum(w * w) + sum(b * b)),
+        1 / N * dot(e, e) + alpha * (w^2),
     )
     optimize!(model)
-    return model, w, b, cons # return model, variables and constraints references
+    return model, w, b # return model & variables
 end


-# Train on the data generated.
+# Plot the data points and the fitted line for different alpha values

-model, w, b, cons = fit_ridge(X, Y)
-ŵ, b̂ = value(w), value(b)
-
-# We can visualize the approximating line.
-
-p = Plots.scatter(X, Y, label="")
+p = Plots.scatter(X, Y, label=nothing, legend=:topleft)
 mi, ma = minimum(X), maximum(X)
-Plots.plot!(p, [mi, ma], [mi * ŵ + b̂, ma * ŵ + b̂], color=:red, label="")
+Plots.title!("Fitted lines and points")

+for alpha in 0.5:0.5:1.5
+    local model, w, b = fit_ridge(X, Y, alpha)
+    ŵ = value(w)
+    b̂ = value(b)
+    Plots.plot!(p, [mi, ma], [mi * ŵ + b̂, ma * ŵ + b̂], label="alpha=$alpha", width=2)
+end
+p

 # ## Differentiate

 # Now that we've solved the problem, we can compute the sensitivity of optimal
-# values of the angular coefficient `w` with
+# values of the slope `w` with
 # respect to perturbations in the data points (`x`,`y`).

-# Begin differentiating the model.
-# analogous to varying θ in the expression:
-# ```math
-# e_i = (y_{i} + \theta_{y_i}) - w (x_{i} + \theta_{x_{i}}) - b
-# ```
+alpha = 0.4
+model, w, b = fit_ridge(X, Y, alpha)
+ŵ = value(w)
+b̂ = value(b)
+
+# We first compute sensitivity of the slope with respect to a perturbation of the independent
+# variable `x`.

-∇ = zero(X)
+# Recalling that the points $(x_i, y_i)$ appear in the objective function as:
+# `(yi - b - w*xi)^2`, the `DiffOpt.ForwardInObjective` attribute must be set accordingly,
+# with the terms multiplying the parameter in the objective.
+# When considering the perturbation of a parameter θ, `DiffOpt.ForwardInObjective()` takes in the expression in the
+# objective that multiplies θ.
+# If θ appears with a quadratic and a linear form: `θ^2 a x + θ b y`, then the expression to pass to
+# `ForwardInObjective` is `2θ a x + b y`.
+
+# Sensitivity with respect to x and y
+
+∇y = zero(X)
+∇x = zero(X)
 for i in 1:N
-    for j in 1:N
-        MOI.set(
-            model,
-            DiffOpt.ForwardInConstraint(),
-            cons[j],
-            i == j ? index(w) + 1.0 : 0.0 * index(w)
-        )
-    end
+    MOI.set(
+        model,
+        DiffOpt.ForwardInObjective(),
+        2w^2 * X[i] + 2b * w - 2 * w * Y[i]
+    )
+    DiffOpt.forward(model)
+    ∇x[i] = MOI.get(
+        model,
+        DiffOpt.ForwardOutVariablePrimal(),
+        w
+    )
+    MOI.set(
+        model,
+        DiffOpt.ForwardInObjective(),
+        (2Y[i] - 2b - 2w * X[i]),
+    )
     DiffOpt.forward(model)
-    dw = MOI.get(
+    ∇y[i] = MOI.get(
         model,
         DiffOpt.ForwardOutVariablePrimal(),
         w
     )
-    ∇[i] = abs(dw)
 end

-normalize!(∇);
+# Visualize point sensitivities with respect to regression points.
+
+p = Plots.scatter(
+    X, Y,
+    color = [dw < 0 ? :blue : :red for dw in ∇x],
+    markersize = [5 * abs(dw) + 1.2 for dw in ∇x],
+    label = ""
+)
+mi, ma = minimum(X), maximum(X)
+Plots.plot!(p, [mi, ma], [mi * ŵ + b̂, ma * ŵ + b̂], color = :blue, label = "")
+Plots.title!("Regression slope sensitivity with respect to x")

-# Visualize point sensitivities with respect to regressing line.
-# Note that the gradients are normalized.
+#

 p = Plots.scatter(
     X, Y,
-    color = [x > 0 ? :red : :blue for x in ∇],
-    markersize = [25 * abs(x) for x in ∇],
+    color = [dw < 0 ? :blue : :red for dw in ∇y],
+    markersize = [5 * abs(dw) + 1.2 for dw in ∇y],
     label = ""
 )
 mi, ma = minimum(X), maximum(X)
-Plots.plot!(p, [mi, ma], [mi * ŵ + b̂, ma * ŵ + b̂], color = :red, label = "")
+Plots.plot!(p, [mi, ma], [mi * ŵ + b̂, ma * ŵ + b̂], color = :blue, label = "")
+Plots.title!("Regression slope sensitivity with respect to y")

-# Note the points in the extremes of the line segment are larger because
-# moving those points has a stronger effect on the angular coefficient of the line.
+# Note that points whose `x` values are further from the center induce a greater sensitivity of the slope with respect to `y`.
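As a worked check of the `ForwardInObjective` rule quoted in the added comments (a sketch, not part of the commit): perturbing X[i] by θ turns the i-th residual term into

    (Y[i] - w * (X[i] + θ) - b)^2
        = (Y[i] - w * X[i] - b)^2 - 2w * θ * (Y[i] - w * X[i] - b) + w^2 * θ^2,

so θ multiplies -2w * (Y[i] - w * X[i] - b) = 2w^2 * X[i] + 2b * w - 2 * w * Y[i] and θ^2 multiplies w^2; applying the stated `θ^2 a x + θ b y -> 2θ a x + b y` rule and evaluating at θ = 0 gives exactly the expression set in the loop above. Perturbing Y[i] instead gives the linear coefficient 2 * (Y[i] - w * X[i] - b), i.e. `(2Y[i] - 2b - 2w * X[i])`. The constant 1 / N factor from the objective is not carried into these expressions; since the forward sensitivities are linear in the passed expression, it would only rescale every ∇x[i] and ∇y[i] by the same positive constant and leave the plots' relative marker sizes unchanged.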
