2525using namespace Clockwork ;
2626
2727typedef struct {
28- double learning_rate;
29- double beta1;
30- double beta2;
31- double weight_decay;
28+ f64 learning_rate;
29+ f64 beta1;
30+ f64 beta2;
31+ f64 weight_decay;
3232} AdamWParams;
3333
3434void print_help (char ** argv) {
35- printf (" Usage: %s [options]\n\n " , argv[0 ]);
36- printf (" Options:\n " );
37- printf (" -h, --help Show this help message and exit.\n " );
38- printf (
39- " -t, --threads <number> Number of threads to use (type: uint32_t, default: %u).\n " ,
40- std::thread::hardware_concurrency () / 2 );
41- printf (
42- " -e, --epochs <number> Number of training epochs (type: int32_t, default: 1000).\n " );
43- printf (" -b, --batch <number> Batch size for training (type: size_t, default: %zu).\n " ,
44- static_cast <size_t >(16 * 16384 ));
45- printf (
46- " -d, --decay <value> Learning rate decay factor per epoch (type: double, default: 0.91).\n " );
47- printf (" \n AdamW Optimizer Parameters:\n " );
48- printf (" --lr <value> Learning rate (type: double, default: 10.0).\n " );
49- printf (" --beta1 <value> Beta1 parameter (type: double, default: 0.9).\n " );
50- printf (" --beta2 <value> Beta2 parameter (type: double, default: 0.999).\n " );
51- printf (" --weight_decay <value> Weight decay (type: double, default: 0.0).\n " );
35+ std::cout << " Usage: " << argv[0 ] << " [options]\n\n " ;
36+ std::cout << " Options:\n " ;
37+ std::cout << " -h, --help Show this help message and exit.\n " ;
38+ std::cout << " -t, --threads <number> Number of threads to use (type: uint32_t, default: "
39+ << std::thread::hardware_concurrency () / 2 << " .\n " ;
40+ std::cout
41+ << " -e, --epochs <number> Number of training epochs (type: int32_t, default: 1000).\n " ;
42+ std::cout << " -b, --batch <number> Batch size for training (type: size_t, default: "
43+ << static_cast <size_t >(16 * 16384 ) << " ).\n " ;
44+ std::cout
45+ << " -d, --decay <value> Learning rate decay factor per epoch (type: double, default: 0.91).\n " ;
46+ std::cout << " \n AdamW Optimizer Parameters:\n " ;
47+ std::cout << " --lr <value> Learning rate (type: double, default: 10.0).\n " ;
48+ std::cout << " --beta1 <value> Beta1 parameter (type: double, default: 0.9).\n " ;
49+ std::cout << " --beta2 <value> Beta2 parameter (type: double, default: 0.999).\n " ;
50+ std::cout << " --weight_decay <value> Weight decay (type: double, default: 0.0).\n " ;
5251}
5352
5453int main (int argc, char ** argv) {
@@ -58,7 +57,7 @@ int main(int argc, char** argv) {
5857 uint32_t thread_count_p = std::thread::hardware_concurrency () / 2 ;
5958 int32_t epochs_p = 1000 ;
6059 size_t batch_size_p = 16 * 16384 ;
61- double decay_p = 0.91 ;
60+ f64 decay_p = 0.91 ;
6261
6362 AdamWParams adam = {.learning_rate = 10.0 , .beta1 = 0.9 , .beta2 = 0.999 , .weight_decay = 0.0 };
6463
@@ -101,14 +100,11 @@ int main(int argc, char** argv) {
101100 // Check if it's a flag without a value or an unknown flag
102101 if (arg.rfind (" --" , 0 ) == 0 || arg.rfind (" -" , 0 ) == 0 ) {
103102 if (i + 1 >= argc || (argv[i + 1 ][0 ] == ' -' && !std::isdigit (argv[i + 1 ][1 ]))) {
104- printf (
105- " Warning! Argument '%s' has a missing value.\n Run %s --help to list all arguments." ,
106- argv[i], argv[0 ]);
107- exit (-1 );
103+ std::cout << " Warning! Argument '" << argv[i] << " ' has a missing value.\n Run "
104+ << argv[0 ] << " --help to list all arguments." : exit (-1 );
108105 } else {
109- printf (
110- " Warning! Arg not recognized: '%s'\n Run %s --help to list all arguments.\n " ,
111- argv[i], argv[0 ]);
106+ std::cout << " Warning! Arg not recognized: '" << argv[i] << " '\n Run "
107+ << argv[0 ] << " --help to list all arguments.\n " ;
112108 exit (-1 );
113109 }
114110 }
0 commit comments