Skip to content

Commit da21ffe

Browse files
authored
Merge pull request #23 from hahassan7/btagFramework
More changes
2 parents 2c31d63 + 9ce361a commit da21ffe

File tree

6 files changed

+203
-331
lines changed

6 files changed

+203
-331
lines changed

PWGJE/Core/JetTaggingUtilities.h

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,15 @@ enum JetTaggingSpecies {
4444
gluon = 5
4545
};
4646

47-
namespace jettaggingutilities
48-
{
49-
5047
enum TaggingMethodNonML {
5148
IPs = 0,
5249
IPs3D = 1,
5350
SV = 2,
5451
SV3D = 3
5552
};
5653

54+
namespace jettaggingutilities
55+
{
5756
const int cmTomum = 10000; // using cm -> #mum for impact parameter (dca)
5857

5958
struct BJetParams {
@@ -711,7 +710,7 @@ template <typename AnyCollision, typename AnalysisJet, typename AnyTracks, typen
711710
int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyTracks const&, AnyParticles const& particles, AnyOriginalParticles const&, std::unordered_map<std::string, std::vector<int>>& trkLabels, bool searchUpToQuark, float vtxResParam = 0.01 /* 0.01cm = 100um */, float trackPtMin = 0.5)
712711
{
713712
const auto& tracks = jet.template tracks_as<AnyTracks>();
714-
const int n_trks = tracks.size();
713+
const int nTrks = tracks.size();
715714

716715
// trkVtxIndex
717716

@@ -725,32 +724,32 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
725724
tempTrkVtxIndex.push_back(i++);
726725
}
727726
tempTrkVtxIndex.push_back(i); // temporary index for PV
728-
if (n_trks < 1) { // the process should be done for n_trks == 1 as well
727+
if (nTrks < 1) { // the process should be done for nTrks == 1 as well
729728
trkLabels["trkVtxIndex"] = tempTrkVtxIndex;
730-
return n_trks;
729+
return nTrks;
731730
}
732731

733-
int n_pos = n_trks + 1;
734-
std::vector<float> dists(n_pos * (n_pos - 1) / 2);
735-
auto trk_pair_idx = [n_pos](int ti, int tj) {
736-
if (ti == tj || ti >= n_pos || tj >= n_pos || ti < 0 || tj < 0) {
732+
int nPos = nTrks + 1;
733+
std::vector<float> dists(nPos * (nPos - 1) / 2);
734+
auto trkPairIdx = [nPos](int ti, int tj) {
735+
if (ti == tj || ti >= nPos || tj >= nPos || ti < 0 || tj < 0) {
737736
LOGF(info, "Track pair index out of range");
738737
return -1;
739738
} else {
740-
return (ti < tj) ? (ti * n_pos - (ti * (ti + 1)) / 2 + tj - ti - 1) : (tj * n_pos - (tj * (tj + 1)) / 2 + ti - tj - 1);
739+
return (ti < tj) ? (ti * nPos - (ti * (ti + 1)) / 2 + tj - ti - 1) : (tj * nPos - (tj * (tj + 1)) / 2 + ti - tj - 1);
741740
}
742-
}; // index n_trks is for PV
741+
}; // index nTrks is for PV
743742

744-
for (int ti = 0; ti < n_pos - 1; ti++)
745-
for (int tj = ti + 1; tj < n_pos; tj++) {
743+
for (int ti = 0; ti < nPos - 1; ti++)
744+
for (int tj = ti + 1; tj < nPos; tj++) {
746745
std::array<float, 3> posi, posj;
747746

748-
if (tj < n_trks) {
747+
if (tj < nTrks) {
749748
if (tracks[tj].has_mcParticle()) {
750749
const auto& pj = tracks[tj].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>();
751750
posj = std::array<float, 3>{pj.vx(), pj.vy(), pj.vz()};
752751
} else {
753-
dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max();
752+
dists[trkPairIdx(ti, tj)] = std::numeric_limits<float>::max();
754753
continue;
755754
}
756755
} else {
@@ -761,24 +760,24 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
761760
const auto& pi = tracks[ti].template mcParticle_as<AnyParticles>().template mcParticle_as<AnyOriginalParticles>();
762761
posi = std::array<float, 3>{pi.vx(), pi.vy(), pi.vz()};
763762
} else {
764-
dists[trk_pair_idx(ti, tj)] = std::numeric_limits<float>::max();
763+
dists[trkPairIdx(ti, tj)] = std::numeric_limits<float>::max();
765764
continue;
766765
}
767766

768-
dists[trk_pair_idx(ti, tj)] = RecoDecay::distance(posi, posj);
767+
dists[trkPairIdx(ti, tj)] = RecoDecay::distance(posi, posj);
769768
}
770769

771770
int clusteri = -1, clusterj = -1;
772-
float min_min_dist = -1.f; // If there is an not-merge-able min_dist pair, check the 2nd-min_dist pair.
771+
float minMinDist = -1.f; // If there is an not-merge-able minDist pair, check the 2nd-minDist pair.
773772
while (true) {
774773

775-
float min_dist = -1.f; // Get min_dist pair
776-
for (int ti = 0; ti < n_pos - 1; ti++)
777-
for (int tj = ti + 1; tj < n_pos; tj++)
774+
float minDist = -1.f; // Get minDist pair
775+
for (int ti = 0; ti < nPos - 1; ti++)
776+
for (int tj = ti + 1; tj < nPos; tj++)
778777
if (tempTrkVtxIndex[ti] != tempTrkVtxIndex[tj] && tempTrkVtxIndex[ti] >= 0 && tempTrkVtxIndex[tj] >= 0) {
779-
float dist = dists[trk_pair_idx(ti, tj)];
780-
if ((dist < min_dist || min_dist < 0.f) && dist > min_min_dist) {
781-
min_dist = dist;
778+
float dist = dists[trkPairIdx(ti, tj)];
779+
if ((dist < minDist || minDist < 0.f) && dist > minMinDist) {
780+
minDist = dist;
782781
clusteri = ti;
783782
clusterj = tj;
784783
}
@@ -787,59 +786,59 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
787786
break;
788787

789788
bool mrg = true; // Merge-ability check
790-
for (int ti = 0; ti < n_pos && mrg; ti++)
789+
for (int ti = 0; ti < nPos && mrg; ti++)
791790
if (tempTrkVtxIndex[ti] == tempTrkVtxIndex[clusteri] && tempTrkVtxIndex[ti] >= 0) {
792-
for (int tj = 0; tj < n_pos && mrg; tj++)
791+
for (int tj = 0; tj < nPos && mrg; tj++)
793792
if (tj != ti && tempTrkVtxIndex[tj] == tempTrkVtxIndex[clusterj] && tempTrkVtxIndex[tj] >= 0) {
794-
if (dists[trk_pair_idx(ti, tj)] > vtxResParam) { // If there is more distant pair compared to vtx_res between two clusters, they cannot be merged.
793+
if (dists[trkPairIdx(ti, tj)] > vtxResParam) { // If there is more distant pair compared to vtx_res between two clusters, they cannot be merged.
795794
mrg = false;
796-
min_min_dist = min_dist;
795+
minMinDist = minDist;
797796
}
798797
}
799798
}
800-
if (min_dist > vtxResParam || min_dist < 0.f)
799+
if (minDist > vtxResParam || minDist < 0.f)
801800
break;
802801

803802
if (mrg) { // Merge two clusters
804-
int old_index = tempTrkVtxIndex[clusterj];
805-
for (int t = 0; t < n_pos; t++)
806-
if (tempTrkVtxIndex[t] == old_index)
803+
int oldIndex = tempTrkVtxIndex[clusterj];
804+
for (int t = 0; t < nPos; t++)
805+
if (tempTrkVtxIndex[t] == oldIndex)
807806
tempTrkVtxIndex[t] = tempTrkVtxIndex[clusteri];
808807
}
809808
}
810809

811-
int n_vertices = 0;
810+
int nVertices = 0;
812811

813812
// Sort the indices from PV (as 0) to the most distant SV (as 1~).
814-
int idxPV = tempTrkVtxIndex[n_trks];
815-
for (int t = 0; t < n_trks; t++)
813+
int idxPV = tempTrkVtxIndex[nTrks];
814+
for (int t = 0; t < nTrks; t++)
816815
if (tempTrkVtxIndex[t] == idxPV) {
817816
tempTrkVtxIndex[t] = -2;
818-
n_vertices = 1; // There is a track originating from PV
817+
nVertices = 1; // There is a track originating from PV
819818
}
820819

821820
std::unordered_map<int, float> avgDistances;
822821
std::unordered_map<int, int> count;
823-
for (int t = 0; t < n_trks; t++) {
822+
for (int t = 0; t < nTrks; t++) {
824823
if (tempTrkVtxIndex[t] >= 0) {
825-
avgDistances[tempTrkVtxIndex[t]] += dists[trk_pair_idx(t, n_trks)];
824+
avgDistances[tempTrkVtxIndex[t]] += dists[trkPairIdx(t, nTrks)];
826825
count[tempTrkVtxIndex[t]]++;
827826
}
828827
}
829828

830-
trkLabels["trkVtxIndex"] = std::vector<int>(n_trks, -1);
831-
if (count.size() != 0) { // If there is any SV cluster not only PV cluster
832-
for (auto& [idx, avgDistance] : avgDistances)
829+
trkLabels["trkVtxIndex"] = std::vector<int>(nTrks, -1);
830+
if (count.size() != 0) { // If there is any SV cluster not only PV cluster
831+
for (auto& [idx, avgDistance] : avgDistances) // o2-linter: disable=const-ref-in-for-loop
833832
avgDistance /= count[idx];
834833

835-
n_vertices += avgDistances.size();
834+
nVertices += avgDistances.size();
836835

837836
std::vector<std::pair<int, float>> sortedIndices(avgDistances.begin(), avgDistances.end());
838837
std::sort(sortedIndices.begin(), sortedIndices.end(), [](const auto& a, const auto& b) { return a.second < b.second; });
839838
int rank = 1;
840839
for (auto const& [idx, avgDistance] : sortedIndices) {
841840
bool found = false;
842-
for (int t = 0; t < n_trks; t++)
841+
for (int t = 0; t < nTrks; t++)
843842
if (tempTrkVtxIndex[t] == idx) {
844843
trkLabels["trkVtxIndex"][t] = rank;
845844
found = true;
@@ -848,7 +847,7 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
848847
}
849848
}
850849

851-
for (int t = 0; t < n_trks; t++)
850+
for (int t = 0; t < nTrks; t++)
852851
if (tempTrkVtxIndex[t] == -2)
853852
trkLabels["trkVtxIndex"][t] = 0;
854853

@@ -868,7 +867,7 @@ int vertexClustering(AnyCollision const& collision, AnalysisJet const& jet, AnyT
868867
trkIdx++;
869868
}
870869

871-
return n_vertices;
870+
return nVertices;
872871
}
873872

874873
std::vector<std::vector<float>> getInputsForML(BJetParams jetparams, std::vector<BJetTrackParams>& tracksParams, std::vector<BJetSVParams>& svsParams, int maxJetConst = 10)

0 commit comments

Comments
 (0)