Skip to content

Commit 97b9d70

Browse files
committed
added diagflat
1 parent eb41e69 commit 97b9d70

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

numojo/core/array_creation_routines.mojo

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,28 @@ fn full[
541541
return NDArray[dtype](shape, fill=tens_value)
542542

543543

544-
fn diagflat():
545-
pass
544+
fn diagflat[dtype: DType](inout v: NDArray[dtype], k: Int = 0) raises -> NDArray[dtype]:
545+
"""
546+
Generate a 2-D NDArray with the flattened input as the diagonal.
547+
548+
Parameters:
549+
dtype: Datatype of the NDArray elements.
550+
551+
Args:
552+
v: NDArray to be flattened and used as the diagonal.
553+
k: Diagonal offset.
554+
555+
Returns:
556+
A 2-D NDArray with the flattened input as the diagonal.
557+
"""
558+
v.reshape(v.ndshape.ndsize, 1)
559+
var n: Int= v.ndshape.ndsize + abs(k)
560+
var result: NDArray[dtype]= NDArray[dtype](n, n, random=False)
561+
562+
for i in range(n):
563+
print(n*i + i + k)
564+
result.store(n*i + i + k, v.data[i])
565+
return result
546566

547567

548568
fn tri():

0 commit comments

Comments
 (0)