diff --git a/README.md b/README.md index c96aed8..b1adade 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ C++ API to log data in tensorboard format. Only support `scalar`, `histogram`, ` ![text](./assets/text.jpg) ![embedding](./assets/embedding.png) ![multiple-image](./assets/multi-image.png) +![pr-curve](./assets/pr_curve.png) # Acknowledgement diff --git a/assets/pr_curve.png b/assets/pr_curve.png new file mode 100644 index 0000000..e542209 Binary files /dev/null and b/assets/pr_curve.png differ diff --git a/include/tensorboard_logger.h b/include/tensorboard_logger.h index e2c72ef..02908fd 100644 --- a/include/tensorboard_logger.h +++ b/include/tensorboard_logger.h @@ -8,6 +8,7 @@ #include "crc.h" #include "event.pb.h" +#include "plugin_pr_curve.pb.h" using tensorflow::Event; using tensorflow::Summary; @@ -44,9 +45,10 @@ class TensorBoardLogger { template int add_histogram(const std::string &tag, int step, const T *value, size_t num) { - if (bucket_limits_ == nullptr) { - generate_default_buckets(); - } + + double max_range = static_cast(*(std::max(value,value+num-1))); + double min_range = static_cast(*(std::min(value,value+num-1))); + generate_default_buckets({max_range, min_range}, num, false, true); std::vector counts(bucket_limits_->size(), 0); double min = std::numeric_limits::max(); @@ -140,9 +142,24 @@ class TensorBoardLogger { const std::vector &metadata = std::vector(), const std::string &metadata_filename = "", int step = 1 /* no effect */); - + int prcurve(const std::string tag, + const std::vectorlabels, + const std::vectorpredictions, + const int num_thresholds = 127, + std::vectorweights = {}, + const std::string &display_name = "", + const std::string &description = ""); private: - int generate_default_buckets(); + std::vector> compute_curve( + const std::vectorlabels, + const std::vectorpredictions, + int num_thresholds = 127, + std::vectorweights = {}); + int generate_default_buckets( + std::vector range = {(-1*1e-12), 1e20}, + size_t num_of_bins = 10, + bool ignore_outside_range = false, + bool regenerate = false); int add_event(int64_t step, Summary *summary); int write(Event &event); diff --git a/proto/plugin_pr_curve.proto b/proto/plugin_pr_curve.proto new file mode 100644 index 0000000..cbf369b --- /dev/null +++ b/proto/plugin_pr_curve.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +package tensorflow; + +message PrCurvePluginData { + // Version `0` is the only supported version. + int32 version = 1; + + uint32 num_thresholds = 2; +} \ No newline at end of file diff --git a/src/tensorboard_logger.cc b/src/tensorboard_logger.cc index 061aa2e..fe8d0ee 100644 --- a/src/tensorboard_logger.cc +++ b/src/tensorboard_logger.cc @@ -15,6 +15,7 @@ #include "event.pb.h" #include "projector_config.pb.h" +#include "plugin_pr_curve.pb.h" using namespace std; using google::protobuf::TextFormat; @@ -26,27 +27,31 @@ using tensorflow::ProjectorConfig; using tensorflow::Summary; using tensorflow::SummaryMetadata; using tensorflow::TensorProto; +using tensorflow::PrCurvePluginData; +using tensorflow::TensorShapeProto; // https://github.com/dmlc/tensorboard/blob/master/python/tensorboard/summary.py#L115 -int TensorBoardLogger::generate_default_buckets() { - if (bucket_limits_ == nullptr) { +int TensorBoardLogger::generate_default_buckets(std::vector range, + size_t num_of_bins, + bool ignore_outside_range, + bool regenerate ) { + if (bucket_limits_ == nullptr || regenerate == true) { bucket_limits_ = new vector; - vector pos_buckets, neg_buckets; - double v = 1e-12; - while (v < 1e20) { - pos_buckets.push_back(v); - neg_buckets.push_back(-v); - v *= 1.1; + double v = range[0]; + double width = (range[1] - range[0]) / num_of_bins ; + if (width == 0) + width = 1; + if(!ignore_outside_range) + bucket_limits_->push_back(numeric_limits::lowest()); + while (v <= range[1]) { + bucket_limits_->push_back(v); + v = v + width; + } + if(!ignore_outside_range) + { + bucket_limits_->push_back(numeric_limits::max()); } - pos_buckets.push_back(numeric_limits::max()); - neg_buckets.push_back(numeric_limits::lowest()); - - bucket_limits_->insert(bucket_limits_->end(), neg_buckets.rbegin(), - neg_buckets.rend()); - bucket_limits_->insert(bucket_limits_->end(), pos_buckets.begin(), - pos_buckets.end()); } - return 0; } @@ -243,6 +248,114 @@ int TensorBoardLogger::add_embedding( tensor_shape, step); } +std::vector> TensorBoardLogger::compute_curve( + const std::vectorlabels, + const std::vectorpredictions, + int num_thresholds, + std::vectorweights) +{ + // misbheaves when thresholds is greater than 127 + num_thresholds = min(num_thresholds,127); + double min_count = 1e-7; + std::vector> data; + while (weights.size() tp(bucket_limits_->size(), 0), fp(bucket_limits_->size(), 0); + + for (size_t i = 0; i < labels.size(); ++i) + { + float v = labels[i]; + int item = predictions[i] * (num_thresholds -1); + auto lb = + lower_bound(bucket_limits_->begin(), bucket_limits_->end(), item); + if(*lb != item) + lb--; + tp[lb - bucket_limits_->begin()] = tp[lb - bucket_limits_->begin()] + (v*weights[i]); + fp[lb - bucket_limits_->begin()] = fp[lb - bucket_limits_->begin()] + ((1-v)*weights[i]); + + } + + // Reverse cummulative sum + for(int i = tp.size() - 2; i >= 0 ;i--) + { + tp[i] = tp[i] + tp[i+1]; + fp[i] = fp[i] + fp[i+1]; + } + std::vector tn(tp.size()), fn(tp.size()), precision(tp.size()), recall(tp.size()); + for(size_t i = 0; i < tp.size() ;i++) + { + fn[i] = tp[0] - tp[i]; + tn[i] = fp[0] - fp[i]; + precision[i] = tp[i] / max(min_count,tp[i]+fp[i]); + recall[i] = tp[i] / max(min_count,tp[i]+fn[i]); + } + data.push_back(tp); + data.push_back(fp); + data.push_back(tn); + data.push_back(fn); + data.push_back(precision); + data.push_back(recall); + return data; +} +int TensorBoardLogger::prcurve( + const std::string tag, + const std::vectorlabels, + const std::vectorpredictions, + const int num_thresholds, + std::vectorweights, + const std::string &display_name, + const std::string &description) +{ + // Pr plugin + PrCurvePluginData *pr_curve_plugin = new PrCurvePluginData(); + pr_curve_plugin->set_version(0); + pr_curve_plugin->set_num_thresholds(num_thresholds); + std::string pr_curve_content; + pr_curve_plugin->SerializeToString(&pr_curve_content); + + // PluginMeta data + auto *plugin_data = new SummaryMetadata::PluginData(); + plugin_data->set_plugin_name("pr_curves"); + plugin_data->set_content(pr_curve_content); + + // Summary Meta data + auto *meta = new SummaryMetadata(); + meta->set_display_name(display_name == "" ? tag : display_name); + meta->set_summary_description(description); + meta->set_allocated_plugin_data(plugin_data); + + std::vector> data = + compute_curve(labels, predictions, num_thresholds, weights); + + // Prepare Tensor + auto *tensorshape = new TensorShapeProto(); + auto rowdim = tensorshape->add_dim(); + rowdim->set_size(data.size()); + auto coldim = tensorshape->add_dim(); + coldim->set_size(data[0].size()); + auto *tensor = new TensorProto(); + tensor->set_dtype(tensorflow::DataType::DT_DOUBLE); + tensor->set_allocated_tensor_shape(tensorshape); + for(int i=0;iadd_double_val(data[i][j]); + } + } + + auto *summary = new Summary(); + auto *v = summary->add_value(); + v->set_tag(tag); + v->set_allocated_tensor(tensor); + v->set_allocated_metadata(meta); + + return add_event(0, summary); +} + int TensorBoardLogger::add_embedding(const std::string &tensor_name, const float *tensor, const std::vector &tensor_shape, diff --git a/tests/test_tensorboard_logger.cc b/tests/test_tensorboard_logger.cc index 584fd2a..2412cc4 100644 --- a/tests/test_tensorboard_logger.cc +++ b/tests/test_tensorboard_logger.cc @@ -106,7 +106,18 @@ int test_log(const char* log_file) { tensor_shape.push_back(tensor[0].size()); logger.add_embedding("binary tensor 1d", tensor_1d, tensor_shape, "tensor_1d.bin", meta, "binary_tensor_1d.tsv"); - delete[] tensor_1d; + delete[] tensor_1d; + + // test pr curver + vector labels, predictions; + for(int i=0;i<100;i++) + { + double item = (double) rand()/RAND_MAX; + double sem_item = (double) rand()/RAND_MAX; + labels.push_back(round(sem_item)); + predictions.push_back(item); + } + logger.prcurve("pr_curve",labels,predictions,127); return 0; }