@@ -39,7 +39,7 @@ paddle_error paddle_matrix_destroy(paddle_matrix mat) {
3939
4040paddle_error paddle_matrix_set_row (paddle_matrix mat,
4141 uint64_t rowID,
42- pd_real * rowArray) {
42+ paddle_real * rowArray) {
4343 if (mat == nullptr ) return kPD_NULLPTR ;
4444 auto ptr = cast (mat);
4545 if (ptr->mat == nullptr ) return kPD_NULLPTR ;
@@ -56,7 +56,7 @@ paddle_error paddle_matrix_set_row(paddle_matrix mat,
5656
5757paddle_error paddle_matrix_get_row (paddle_matrix mat,
5858 uint64_t rowID,
59- pd_real ** rawRowBuffer) {
59+ paddle_real ** rawRowBuffer) {
6060 if (mat == nullptr ) return kPD_NULLPTR ;
6161 auto ptr = cast (mat);
6262 if (ptr->mat == nullptr ) return kPD_NULLPTR ;
@@ -78,3 +78,46 @@ paddle_error paddle_matrix_get_shape(paddle_matrix mat,
7878 return kPD_NO_ERROR ;
7979}
8080}
81+
82+ paddle_matrix paddle_matrix_create_sparse (
83+ uint64_t height, uint64_t width, uint64_t nnz, bool isBinary, bool useGpu) {
84+ auto ptr = new paddle::capi::CMatrix ();
85+ ptr->mat = paddle::Matrix::createSparseMatrix (
86+ height,
87+ width,
88+ nnz,
89+ isBinary ? paddle::NO_VALUE : paddle::FLOAT_VALUE,
90+ paddle::SPARSE_CSR,
91+ false ,
92+ useGpu);
93+ return ptr;
94+ }
95+
96+ paddle_error paddle_matrix_sparse_copy_from (paddle_matrix mat,
97+ int * rowArray,
98+ uint64_t rowSize,
99+ int * colArray,
100+ uint64_t colSize,
101+ float * valueArray,
102+ uint64_t valueSize) {
103+ if (mat == nullptr ) return kPD_NULLPTR ;
104+ auto ptr = cast (mat);
105+ if (rowArray == nullptr || colArray == nullptr ||
106+ (valueSize != 0 && valueArray == nullptr ) || ptr->mat == nullptr ) {
107+ return kPD_NULLPTR ;
108+ }
109+ if (auto sparseMat = dynamic_cast <paddle::CpuSparseMatrix*>(ptr->mat .get ())) {
110+ std::vector<int > row (rowSize);
111+ row.assign (rowArray, rowArray + rowSize);
112+ std::vector<int > col (colSize);
113+ col.assign (colArray, colArray + colSize);
114+ std::vector<paddle_real> val (valueSize);
115+ if (valueSize) {
116+ val.assign (valueArray, valueArray + valueSize);
117+ }
118+ sparseMat->copyFrom (row, col, val);
119+ return kPD_NO_ERROR ;
120+ } else {
121+ return kPD_NOT_SUPPORTED ;
122+ }
123+ }
0 commit comments