77#include < vector>
88
99// #include "preprocessing.hpp"
10+ #include " b64.cpp"
1011#include " flux.hpp"
12+ #include " json.hpp"
1113#include " stable-diffusion.h"
1214
1315#define STB_IMAGE_IMPLEMENTATION
@@ -49,7 +51,6 @@ const char* schedule_str[] = {
4951 " ays" ,
5052};
5153
52-
5354enum SDMode {
5455 TXT2IMG,
5556 IMG2IMG,
@@ -86,7 +87,6 @@ struct SDParams {
8687 int height = 512 ;
8788 int batch_count = 1 ;
8889
89-
9090 sample_method_t sample_method = EULER_A;
9191 schedule_t schedule = DEFAULT;
9292 int sample_steps = 20 ;
@@ -100,9 +100,9 @@ struct SDParams {
100100 bool vae_on_cpu = false ;
101101 bool color = false ;
102102
103- // server things
104- int port = 8080 ;
105- std::string host = " 127.0.0.1" ;
103+ // server things
104+ int port = 8080 ;
105+ std::string host = " 127.0.0.1" ;
106106};
107107
108108void print_params (SDParams params) {
@@ -227,7 +227,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
227227 break ;
228228 }
229229 params.vae_path = argv[i];
230- // TODO Tiny AE
230+ // TODO Tiny AE
231231 } else if (arg == " --type" ) {
232232 if (++i >= argc) {
233233 invalid_arg = true ;
@@ -565,27 +565,113 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
565565 fflush (out_stream);
566566}
567567
568- static void log_server_request (const httplib::Request & req, const httplib::Response & res) {
568+ static void log_server_request (const httplib::Request& req, const httplib::Response& res) {
569569 printf (" request: %s %s (%s)\n " , req.method .c_str (), req.path .c_str (), req.body .c_str ());
570570}
571571
572+ void parseJsonPrompt (std::string json_str, SDParams* params) {
573+ using namespace nlohmann ;
574+ json payload = json::parse (json_str);
575+ // if no exception, the request is a json object
576+ // now we try to get the new param values from the payload object
577+ // const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
578+ try {
579+ std::string prompt = payload[" prompt" ];
580+ params->prompt = prompt;
581+ } catch (...) {
582+ }
583+ try {
584+ std::string negative_prompt = payload[" negative_prompt" ];
585+ params->negative_prompt = negative_prompt;
586+ } catch (...) {
587+ }
588+ try {
589+ int clip_skip = payload[" clip_skip" ];
590+ params->clip_skip = clip_skip;
591+ } catch (...) {
592+ }
593+ try {
594+ float cfg_scale = payload[" cfg_scale" ];
595+ params->cfg_scale = cfg_scale;
596+ } catch (...) {
597+ }
598+ try {
599+ float guidance = payload[" guidance" ];
600+ params->guidance = guidance;
601+ } catch (...) {
602+ }
603+ try {
604+ int width = payload[" width" ];
605+ params->width = width;
606+ } catch (...) {
607+ }
608+ try {
609+ int height = payload[" height" ];
610+ params->height = height;
611+ } catch (...) {
612+ }
613+ try {
614+ std::string sample_method = payload[" sample_method" ];
615+ // TODO map to enum value
616+ LOG_WARN (" sample_method is not supported yet\n " );
617+ } catch (...) {
618+ }
619+ try {
620+ int sample_steps = payload[" sample_steps" ];
621+ params->sample_steps = sample_steps;
622+ } catch (...) {
623+ }
624+ try {
625+ int64_t seed = payload[" seed" ];
626+ params->seed = seed;
627+ } catch (...) {
628+ }
629+ try {
630+ int batch_count = payload[" batch_count" ];
631+ params->batch_count = batch_count;
632+ } catch (...) {
633+ }
634+
635+ try {
636+ std::string control_cond = payload[" control_cond" ];
637+ // TODO map to enum value
638+ LOG_WARN (" control_cond is not supported yet\n " );
639+ } catch (...) {
640+ }
641+ try {
642+ float control_strength = payload[" control_strength" ];
643+ } catch (...) {
644+ }
645+ try {
646+ float style_strength = payload[" style_strength" ];
647+ } catch (...) {
648+ }
649+ try {
650+ bool normalize_input = payload[" normalize_input" ];
651+ params->normalize_input = normalize_input;
652+ } catch (...) {
653+ }
654+ try {
655+ std::string input_id_images_path = payload[" input_id_images_path" ];
656+ // TODO replace with b64 image maybe?
657+ } catch (...) {
658+ }
659+ }
572660
573661int main (int argc, const char * argv[]) {
574662 SDParams params;
575663
576664 parse_args (argc, argv, params);
577665
578-
579666 sd_set_log_callback (sd_log_cb, (void *)¶ms);
580667
581668 if (params.verbose ) {
582669 print_params (params);
583670 printf (" %s" , sd_get_system_info ());
584671 }
585672
673+ bool vae_decode_only = true ;
586674
587- bool vae_decode_only = true ;
588-
589675 sd_ctx_t * sd_ctx = new_sd_ctx (params.model_path .c_str (),
590676 params.clip_l_path .c_str (),
591677 params.t5xxl_path .c_str (),
@@ -614,33 +700,48 @@ int main(int argc, const char* argv[]) {
614700
615701 int n_prompts = 0 ;
616702
617- const auto txt2imgRequest = [&sd_ctx, ¶ms, &n_prompts](const httplib::Request & req, httplib::Response & res) {
618- // TODO: proper payloads
619- std::string prompt = req.body ;
620- if (!prompt.empty ()){
621- params.prompt = prompt;
622- }else {
623- params.seed +=1 ;
703+ const auto txt2imgRequest = [&sd_ctx, ¶ms, &n_prompts](const httplib::Request& req, httplib::Response& res) {
704+ LOG_INFO (" raw body is: %s\n " , req.body .c_str ());
705+ // parse req.body as json using jsoncpp
706+ using json = nlohmann::json;
707+
708+ try {
709+ std::string json_str = req.body ;
710+ parseJsonPrompt (json_str, ¶ms);
711+ } catch (json::parse_error& e) {
712+ // assume the request is just a prompt
713+ LOG_WARN (" Failed to parse json: %s\n Assuming it's just a prompt...\n " , e.what ());
714+ std::string prompt = req.body ;
715+ if (!prompt.empty ()) {
716+ params.prompt = prompt;
717+ } else {
718+ params.seed += 1 ;
719+ }
720+ } catch (...) {
721+ // Handle any other type of exception
722+ LOG_ERROR (" An unexpected error occurred\n " );
624723 }
724+ LOG_INFO (" prompt is: %s\n " , params.prompt .c_str ());
725+
625726 {
626727 sd_image_t * results;
627728 results = txt2img (sd_ctx,
628- params.prompt .c_str (),
629- params.negative_prompt .c_str (),
630- params.clip_skip ,
631- params.cfg_scale ,
632- params.guidance ,
633- params.width ,
634- params.height ,
635- params.sample_method ,
636- params.sample_steps ,
637- params.seed ,
638- params.batch_count ,
639- NULL ,
640- 1 ,
641- params.style_ratio ,
642- params.normalize_input ,
643- " " );
729+ params.prompt .c_str (),
730+ params.negative_prompt .c_str (),
731+ params.clip_skip ,
732+ params.cfg_scale ,
733+ params.guidance ,
734+ params.width ,
735+ params.height ,
736+ params.sample_method ,
737+ params.sample_steps ,
738+ params.seed ,
739+ params.batch_count ,
740+ NULL ,
741+ 1 ,
742+ params.style_ratio ,
743+ params.normalize_input ,
744+ " " );
644745
645746 if (results == NULL ) {
646747 printf (" generate failed\n " );
@@ -650,52 +751,67 @@ int main(int argc, const char* argv[]) {
650751
651752 size_t last = params.output_path .find_last_of (" ." );
652753 std::string dummy_name = last != std::string::npos ? params.output_path .substr (0 , last) : params.output_path ;
754+ json images_json = json::array ();
653755 for (int i = 0 ; i < params.batch_count ; i++) {
654756 if (results[i].data == NULL ) {
655757 continue ;
656758 }
657- std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 + n_prompts*params.batch_count ) + " .png" : dummy_name + " .png" ;
759+ // TODO allow disable save to disk
760+ std::string final_image_path = i > 0 ? dummy_name + " _" + std::to_string (i + 1 + n_prompts * params.batch_count ) + " .png" : dummy_name + " .png" ;
658761 stbi_write_png (final_image_path.c_str (), results[i].width , results[i].height , results[i].channel ,
659- results[i].data , 0 , get_image_params (params, params.seed + i).c_str ());
762+ results[i].data , 0 , get_image_params (params, params.seed + i).c_str ());
660763 printf (" save result image to '%s'\n " , final_image_path.c_str ());
661- // Todo: return base64 encoded image via websocket?
764+ // Todo: return base64 encoded image via httplib::Response& res
765+
766+ int len;
767+ unsigned char * png = stbi_write_png_to_mem ((const unsigned char *)results[i].data , 0 , results[i].width , results[i].height , results[i].channel , &len, NULL );
768+
769+ std::string data_str (png, png + len);
770+ std::string encoded_img = base64_encode (data_str);
771+
772+ images_json.push_back ({{" width" , results[i].width },
773+ {" height" , results[i].height },
774+ {" channel" , results[i].channel },
775+ {" data" , encoded_img},
776+ {" encoding" , " png" }});
777+
662778 free (results[i].data );
663779 results[i].data = NULL ;
664780 }
665781 free (results);
666782 n_prompts++;
783+ res.set_content (images_json.dump (), " application/json" );
667784 }
668785 return 0 ;
669786 };
670787
671-
672788 std::unique_ptr<httplib::Server> svr;
673789 svr.reset (new httplib::Server ());
674790 svr->set_default_headers ({{" Server" , " sd.cpp" }});
675791 // CORS preflight
676- svr->Options (R"( .*)" , [](const httplib::Request &, httplib::Response & res) {
792+ svr->Options (R"( .*)" , [](const httplib::Request&, httplib::Response& res) {
677793 // Access-Control-Allow-Origin is already set by middleware
678794 res.set_header (" Access-Control-Allow-Credentials" , " true" );
679- res.set_header (" Access-Control-Allow-Methods" , " POST" );
680- res.set_header (" Access-Control-Allow-Headers" , " *" );
681- return res.set_content (" " , " text/html" ); // blank response, no data
795+ res.set_header (" Access-Control-Allow-Methods" , " POST" );
796+ res.set_header (" Access-Control-Allow-Headers" , " *" );
797+ return res.set_content (" " , " text/html" ); // blank response, no data
682798 });
683799 svr->set_logger (log_server_request);
684800
685801 svr->Post (" /txt2img" , txt2imgRequest);
686802
687-
688803 // bind HTTP listen port, run the HTTP server in a thread
689804 if (!svr->bind_to_port (params.host , params.port )) {
690- // TODO: Error message
805+ // TODO: Error message
691806 return 1 ;
692- }
807+ }
693808 std::thread t ([&]() { svr->listen_after_bind (); });
694809 svr->wait_until_ready ();
695810
696- printf (" Server listening at %s:%d\n " ,params.host .c_str (),params.port );
811+ printf (" Server listening at %s:%d\n " , params.host .c_str (), params.port );
697812
698- while (1 );
813+ while (1 )
814+ ;
699815
700816 free_sd_ctx (sd_ctx);
701817
0 commit comments