Skip to content

Commit 5e8e79b

Browse files
authored
Update Sbp.jl
1 parent 8195b88 commit 5e8e79b

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

src/Sbp.jl

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ const SMALLEST_TREE_SEGMENT = 3
7373
const FAILURE_RECURSION_SIZE = -21
7474
const STD_DIM_SIZE = 7
7575
const ZERO_DIM = zeros(Float16, STD_DIM_SIZE)
76-
const EMPTY_DIM = Float16[]
76+
const EMPTY_DIM = Float16[typemax(Float16) for _ in 1:STD_DIM_SIZE]
7777

7878
using OrderedCollections
7979
using Random
@@ -87,10 +87,14 @@ export zero_unit_backward, mul_unit_backward, div_unit_backward, equal_unit_back
8787
export get_feature_dims_json, get_target_dim_json, retrieve_coeffs_based_on_similarity
8888
export ZERO_DIM
8989

90-
function equal_unit_forward(u1::Vector{Float16}, u2::Vector{Float16})
91-
if isempty(u1) || isempty(u2)
92-
return EMPTY_DIM
90+
@inline function has_inf16(u::Vector{Float16})
91+
@inbounds for x in u
92+
reinterpret(UInt16, x) == 0x7c00 && return true
9393
end
94+
return false
95+
end
96+
97+
function equal_unit_forward(u1::Vector{Float16}, u2::Vector{Float16})
9498
@inbounds return all(u1 .== u2) ? u1 : EMPTY_DIM
9599
end
96100

@@ -103,15 +107,12 @@ function arbitrary_unit_forward(u1::Vector{Float16})
103107
return u1
104108
end
105109

106-
function mul_unit_forward(u1::Vector{Float16}, u2::Vector{Float16})
107-
if isempty(u1) || isempty(u2)
108-
return EMPTY_DIM
109-
end
110+
function mul_unit_forward(u1::Vector{Float16}, u2::Vector{Float16})
110111
return u1 .+ u2
111112
end
112113

113114
function mul_unit_backward(u1::Vector{Float16}, u2::Vector{Float16}, expected_dim::Vector{Float16})
114-
if isempty(u2) && isempty(u1)
115+
if has_inf16(u2) && has_inf16(u1)
115116
if 0.5 < rand()
116117
lr = expected_dim
117118
rr = ZERO_DIM
@@ -120,9 +121,9 @@ function mul_unit_backward(u1::Vector{Float16}, u2::Vector{Float16}, expected_di
120121
lr = ZERO_DIM
121122
end
122123
return lr, rr
123-
elseif isempty(u2)
124+
elseif has_inf16(u2)
124125
return u1, expected_dim .- u1
125-
elseif isempty(u1)
126+
elseif has_inf16(u1)
126127
return expected_dim .- u2, u2
127128
else
128129
if isapprox(u1, u2, atol=eps(Float16))
@@ -139,15 +140,12 @@ end
139140

140141

141142
function div_unit_forward(u1::Vector{Float16}, u2::Vector{Float16})
142-
if isempty(u1) || isempty(u2)
143-
return EMPTY_DIM
144-
end
145143
return u1 .- u2
146144
end
147145

148146

149147
function div_unit_backward(u1::Vector{Float16}, u2::Vector{Float16}, expected_dim::Vector{Float16})
150-
if isempty(u2) && isempty(u1)
148+
if has_inf16(u2) && has_inf16(u1)
151149
if 0.5 < rand()
152150
lr = expected_dim
153151
rr = ZERO_DIM
@@ -156,9 +154,9 @@ function div_unit_backward(u1::Vector{Float16}, u2::Vector{Float16}, expected_di
156154
lr = ZERO_DIM
157155
end
158156
return lr, rr
159-
elseif isempty(u2)
157+
elseif has_inf16(u2)
160158
return u1, .-(expected_dim .+ u1)
161-
elseif isempty(u1)
159+
elseif has_inf16(u1)
162160
return expected_dim .+ u2, u2
163161
else
164162
if isapprox(u1, u2, atol=eps(Float16))
@@ -175,9 +173,6 @@ end
175173

176174

177175
function zero_unit_forward(u1::Vector{Float16})
178-
if isempty(u1)
179-
return EMPTY_DIM
180-
end
181176
@inbounds return all(u1 .== 0) ? ZERO_DIM .* u1 : EMPTY_DIM
182177
end
183178

0 commit comments

Comments
 (0)