@@ -752,18 +752,15 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
752752 return ggml_timestep_embedding (ctx, timesteps, dim, max_period);
753753}
754754
755- // struct GGMLComputeGraph {
756- // virtual void init(struct ggml_context* ctx, ggml_type wtype) = 0;
757- // virtual std::string get_desc() = 0;
758- // virtual size_t get_params_mem_size() = 0;
759- // virtual size_t get_params_num() = 0;
760- // virtual struct ggml_cgraph* get_ggml_cgraph() = 0;
761- // };
762-
763- /*
764- #define MAX_PARAMS_TENSOR_NUM 10240
765- #define MAX_GRAPH_SIZE 10240
766- */
755+
756+ __STATIC_INLINE__ size_t ggml_tensor_num (ggml_context * ctx) {
757+ size_t num = 0 ;
758+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != nullptr ; t = ggml_get_next_tensor (ctx, t)) {
759+ num++;
760+ }
761+ return num;
762+ }
763+
767764/* SDXL with LoRA requires more space */
// Both limits are set to 15360 here; the commented-out definitions removed
// earlier in this diff used 10240 for each.
// NOTE(review): presumably these cap the number of parameter tensors and the
// compute-graph size — the use sites are not visible here, confirm before relying on it.
768765#define MAX_PARAMS_TENSOR_NUM 15360
769766#define MAX_GRAPH_SIZE 15360
@@ -854,8 +851,6 @@ struct GGMLModule {
854851 }
855852
856853public:
857- virtual size_t get_params_mem_size () = 0;
858- virtual size_t get_params_num () = 0;
859854 virtual std::string get_desc () = 0;
860855
861856 GGMLModule (ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
@@ -876,7 +871,7 @@ struct GGMLModule {
876871 }
877872
878873 bool alloc_params_buffer () {
879- size_t num_tensors = get_params_num ( );
874+ size_t num_tensors = ggml_tensor_num (params_ctx );
880875 params_buffer = ggml_backend_alloc_ctx_tensors (params_ctx, backend);
881876 if (params_buffer == NULL ) {
882877 LOG_ERROR (" %s alloc params backend buffer failed" , get_desc ().c_str ());
@@ -898,6 +893,13 @@ struct GGMLModule {
898893 }
899894 }
900895
896+ size_t get_params_buffer_size () {
897+ if (params_buffer != NULL ) {
898+ return ggml_backend_buffer_get_size (params_buffer);
899+ }
900+ return 0 ;
901+ }
902+
901903 void free_compute_buffer () {
902904 if (compute_allocr != NULL ) {
903905 ggml_gallocr_free (compute_allocr);
@@ -968,19 +970,6 @@ struct GGMLModule {
968970};
969971
970972class GGMLBlock {
971- private:
972- static char temp_buffer[1024 * 1024 * 10 ];
973- ggml_context* get_temp_ctx () {
974- struct ggml_init_params params;
975- params.mem_size = sizeof (temp_buffer);
976- params.mem_buffer = temp_buffer;
977- params.no_alloc = true ;
978-
979- ggml_context* temp_ctx = ggml_init (params);
980- GGML_ASSERT (temp_ctx != NULL );
981- return temp_ctx;
982- }
983-
984973protected:
985974 typedef std::unordered_map<std::string, struct ggml_tensor *> ParameterMap;
986975 typedef std::unordered_map<std::string, std::shared_ptr<GGMLBlock>> GGMLBlockMap;
@@ -1003,14 +992,6 @@ class GGMLBlock {
1003992 init_params (ctx, wtype);
1004993 }
1005994
1006- std::tuple<size_t , size_t > get_params_info (ggml_type wtype) {
1007- ggml_context* temp_ctx = get_temp_ctx ();
1008- init (temp_ctx, wtype);
1009- size_t num_tensors = get_params_num ();
1010- size_t mem_size = get_params_mem_size ();
1011- return {num_tensors, mem_size};
1012- }
1013-
1014995 size_t get_params_num () {
1015996 size_t num_tensors = params.size ();
1016997 for (auto & pair : blocks) {
0 commit comments