Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions CLDConfig/CLDReconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import os
from Gaudi.Configuration import INFO, WARNING, DEBUG

from Configurables import k4DataSvc, MarlinProcessorWrapper
from k4MarlinWrapper.inputReader import create_reader, attach_edm4hep2lcio_conversion
from Configurables import MarlinProcessorWrapper, Lcio2EDM4hepTool
from k4FWCore.parseArgs import parser
from py_utils import SequenceLoader, attach_lcio2edm4hep_conversion, create_writer, parse_collection_patch_file
from k4FWCore import ApplicationMgr, IOSvc
from py_utils import SequenceLoader, attach_lcio2edm4hep_conversion, attach_edm4hep2lcio_conversion, create_writer, parse_collection_patch_file, attach_lcio2edm4hep_conversion_for_tagging

import ROOT
ROOT.gROOT.SetBatch(True)
Expand All @@ -33,6 +33,8 @@
parser_group.add_argument("--outputBasename", help="Basename of the output file(s)", default="output")
parser_group.add_argument("--trackingOnly", action="store_true", help="Run only track reconstruction", default=False)
parser_group.add_argument("--enableLCFIJet", action="store_true", help="Enable LCFIPlus jet clustering parts", default=False)
parser_group.add_argument("--enableMLJetTagger", action="store_true", help="Enable ML-based jet flavor tagging", default=False)
parser_group.add_argument("--MLJetTaggerModel", action="store", help="Type of ML model to use for inference", type=str, default="model_ParT_ecm240_cld_o2_v5")
parser_group.add_argument("--cms", action="store", help="Choose a Centre-of-Mass energy", default=240, choices=(91, 160, 240, 365), type=int)
parser_group.add_argument("--compactFile", help="Compact detector file to use", type=str, default=os.environ["K4GEO"] + "/FCCee/CLD/compact/CLD_o2_v07/CLD_o2_v07.xml")
tracking_group = parser_group.add_mutually_exclusive_group()
Expand All @@ -43,8 +45,14 @@
algList = []
svcList = []

evtsvc = k4DataSvc("EventDataSvc")
svcList.append(evtsvc)
if not reco_args.inputFiles:
print('WARNING: No input files specified, the CLD Reconstruction will fail')
reco_args.inputFiles = []

io_svc = IOSvc("IOSvc")
io_svc.Input = reco_args.inputFiles
io_svc.Output = f"{reco_args.outputBasename}.edm4hep.root"
svcList.append(io_svc)

CONFIG = {
"CalorimeterIntegrationTimeWindow": "10ns",
Expand All @@ -56,10 +64,9 @@
"OutputMode": "EDM4Hep",
"OutputModeChoices": ["LCIO", "EDM4hep"] #, "both"] FIXME: both is not implemented yet
}

REC_COLLECTION_CONTENTS_FILE = "collections_rec_level.txt" # file with the collections to be patched in when writing from LCIO to EDM4hep

from Configurables import GeoSvc, TrackingCellIDEncodingSvc, Lcio2EDM4hepTool
from Configurables import GeoSvc, TrackingCellIDEncodingSvc
geoservice = GeoSvc("GeoSvc")
geoservice.detectors = [reco_args.compactFile]
geoservice.OutputLevel = INFO
Expand Down Expand Up @@ -92,14 +99,6 @@
},
)

if reco_args.inputFiles:
read = create_reader(reco_args.inputFiles, evtsvc)
read.OutputLevel = INFO
algList.append(read)
else:
print('WARNING: No input files specified, the CLD Reconstruction will fail')
read = None

MyAIDAProcessor = MarlinProcessorWrapper("MyAIDAProcessor")
MyAIDAProcessor.OutputLevel = WARNING
MyAIDAProcessor.ProcessorType = "AIDAProcessor"
Expand Down Expand Up @@ -144,6 +143,24 @@
sequenceLoader.load("HighLevelReco/PFOSelector")
sequenceLoader.load("HighLevelReco/JetClusteringOrRenaming")
sequenceLoader.load("HighLevelReco/JetAndVertex")

# jet-flavor tagging
if not reco_args.trackingOnly and reco_args.enableMLJetTagger:
# convert all lcio collections to edm4hep - tagger expects edm4hep collections

# Make sure that all collections are always available by patching in missing ones on-the-fly
collPatcher_4tagging = MarlinProcessorWrapper(
"CollPatcher_4tagging", OutputLevel=INFO, ProcessorType="PatchCollections"
)
collPatcher_4tagging.Parameters = {
"PatchCollections": parse_collection_patch_file(REC_COLLECTION_CONTENTS_FILE)
}
algList.append(collPatcher_4tagging)
# actual conversion
attach_lcio2edm4hep_conversion_for_tagging(algList)
# add the tagger
sequenceLoader.load("HighLevelReco/MLJetTagger")

# event number processor, down here to attach the conversion back to edm4hep to it
algList.append(EventNumber)

Expand All @@ -169,21 +186,23 @@
}
algList.append(collPatcherRec)

Output_REC = create_writer("edm4hep", "Output_REC", f"{reco_args.outputBasename}_REC")
algList.append(Output_REC)
# keep all collections
io_svc.outputCommands = ["keep *"]

# FIXME: add option to write only selected collections with SVC

# FIXME: needs https://github.com/key4hep/k4FWCore/issues/226
# Output_DST = create_writer("edm4hep", "Output_DST", f"{reco_args.outputBasename}_DST", DST_KEEPLIST)
# algList.append(Output_DST)


# We need to convert the inputs in case we have EDM4hep input
attach_edm4hep2lcio_conversion(algList, read)
attach_edm4hep2lcio_conversion(algList) # , read)

# We need to convert the outputs in case we have EDM4hep output
attach_lcio2edm4hep_conversion(algList)

from Configurables import ApplicationMgr

ApplicationMgr( TopAlg = algList,
EvtSel = 'NONE',
EvtMax = 3, # Overridden by the --num-events switch to k4run
Expand Down
72 changes: 72 additions & 0 deletions CLDConfig/HighLevelReco/MLJetTagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#
# Copyright (c) 2014-2024 Key4hep-Project.
#
# This file is part of Key4hep.
# See https://key4hep.github.io/key4hep-doc/ for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from Gaudi.Configuration import WARNING, INFO, DEBUG
from Configurables import JetTagger, Lcio2EDM4hepTool, MarlinProcessorWrapper
import yaml
import os

if reco_args.enableMLJetTagger:
# check if jet clustering is also enabled (prerequisite for jet flavor tagging)
if not reco_args.enableLCFIJet:
raise ValueError("MLJetTagger requires LCFIPlus jet clustering to be enabled. Please add --enableLCFIJet to the command or disable --enableMLJetTagger.")

# Get the directory of the current script
script_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the YAML file
yaml_path = os.path.join(script_dir, "models_MLJetTagger.yaml")

# Load YAML config
with open(yaml_path, "r") as file:
model_config = yaml.safe_load(file)

# check if the model type is valid
if reco_args.MLJetTaggerModel not in model_config:
raise ValueError(f"Invalid model type '{reco_args.MLJetTaggerModel}'. Valid options are: {', '.join(model_config.keys())}.")

# load the model configuration
onnx_model = model_config[reco_args.MLJetTaggerModel]["onnx_model"]
json_onnx_config = model_config[reco_args.MLJetTaggerModel]["json_onnx_config"]
flavor_collection_names = model_config[reco_args.MLJetTaggerModel]["flavor_collection_names"]

# print out the model configuration
print("RUNNING JET TAGGING WITH MLJETTAGGER")

print(f"Using MLJetTagger model: \t\t {reco_args.MLJetTaggerModel}\n",
f"The model uses the architecture: \t {model_config[reco_args.MLJetTaggerModel]['model']}\n",
f"was trained on the kinematics: \t {model_config[reco_args.MLJetTaggerModel]['kinematics']}\n",
f"and the detector version: \t\t {model_config[reco_args.MLJetTaggerModel]['detector']}\n",
f"at a center-of-mass energy of: \t {model_config[reco_args.MLJetTaggerModel]['ecm']} GeV\n",
f"Comment: \t\t\t\t {model_config[reco_args.MLJetTaggerModel]['comment']}\n",
f"Appending collections to the event: \t {', '.join(flavor_collection_names)}\n",)

# create the MLJetTagger algorithm
k4MLJetTagger = JetTagger("JetTagger",
model_path=onnx_model,
json_path=json_onnx_config,
flavor_collection_names = flavor_collection_names, # to make sure the order and nameing is correct
InputJets=["RefinedVertexJets"],
InputPrimaryVertices=["PrimaryVertices"],
OutputIDCollections=flavor_collection_names,
OutputLevel=DEBUG,
)

# append sequence to the algorithm list
MLJetTaggerSequence = [
k4MLJetTagger,
]
18 changes: 18 additions & 0 deletions CLDConfig/HighLevelReco/models_MLJetTagger.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# this yaml file stores and should be filled in with information about how a jet-flavor tagger is trained and the necessary information to run inference

model_ParT_ecm240_cld_o2_v5:
model: "ParticleTransformer"
ecm: 240
detector: "CLD_o2_v5"
kinematics: "Z(vv)H(jj)"
onnx_model: "/eos/experiment/fcc/ee/jet_flavour_tagging/fullsim_test_spring2024/fullsimCLD240_2mio.onnx"
json_onnx_config: "/eos/experiment/fcc/ee/jet_flavour_tagging/fullsim_test_spring2024/preprocess_fullsimCLD240_2mio.json"
flavor_collection_names:
- "RefinedJetTag_G"
- "RefinedJetTag_U"
- "RefinedJetTag_S"
- "RefinedJetTag_C"
- "RefinedJetTag_B"
- "RefinedJetTag_D"
- "RefinedJetTag_TAU"
comment: "The model was trained on 1.9 mio/jets per flavor. First implementation of ML tagging for full sim."
Empty file added CLDConfig/cdb.log
Empty file.
56 changes: 54 additions & 2 deletions CLDConfig/py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import importlib.util
import importlib.abc
from importlib.machinery import SourceFileLoader
from Configurables import PodioOutput, MarlinProcessorWrapper
from Configurables import PodioOutput, MarlinProcessorWrapper, EDM4hep2LcioTool
from typing import Iterable
from Gaudi.Configuration import WARNING
from Gaudi.Configuration import WARNING, DEBUG


def import_from(
Expand Down Expand Up @@ -132,7 +132,42 @@ def load(self, sequence: str) -> None:
seq = getattr(seq_module, seq_name)
self.alg_list.extend(seq)

def create_reader(input_files, evtSvc):
# FIXME: from https://github.com/key4hep/k4MarlinWrapper/blob/main/k4MarlinWrapper/python/k4MarlinWrapper/inputReader.py#L24-L40 but adapt it to IOSvc
"""Create the appropriate reader for the input files"""
if input_files[0].endswith(".slcio"):
if any(not f.endswith(".slcio") for f in input_files):
print("All input files need to have the same format (LCIO)")
sys.exit(1)

read = LcioEvent()
read.Files = input_files
else:
if any(not f.endswith(".root") for f in input_files):
print("All input files need to have the same format (EDM4hep)")
sys.exit(1)
read = PodioInput("PodioInput")
evtSvc.inputs = input_files

return read


def attach_edm4hep2lcio_conversion(algList):
"""Attach the edm4hep to lcio conversion if necessary e.g. when using create_reader. Should only be run after algList is complete."""
# if not isinstance(read, PodioInput):
# # nothing to convert :)
# return

# find first wrapper
for alg in algList:
if isinstance(alg, MarlinProcessorWrapper):
break

EDM4hep2LcioInput = EDM4hep2LcioTool("InputConversion")
EDM4hep2LcioInput.convertAll = True
# Adjust for the different naming conventions
EDM4hep2LcioInput.collNameMapping = {"MCParticles": "MCParticle"}
alg.EDM4hep2LcioTool = EDM4hep2LcioInput

def attach_lcio2edm4hep_conversion(algList: list) -> None:
"""Attaches a conversion from lcio to edm4hep at the last MarlinWrapper in algList if necessary
Expand All @@ -155,6 +190,23 @@ def attach_lcio2edm4hep_conversion(algList: list) -> None:

alg.Lcio2EDM4hepTool = lcioConvTool

def attach_lcio2edm4hep_conversion_for_tagging(algList: list) -> None:
"""Attaches a conversion from lcio to edm4hep at the last MarlinWrapper in algList just before tagging, as the tagger expect edm4hep collections
"""
# find last marlin wrapper
for alg in reversed(algList):
if isinstance(alg, MarlinProcessorWrapper):
break

from Configurables import Lcio2EDM4hepTool
lcioConvTool_4tagging = Lcio2EDM4hepTool("lcio2EDM4hep")
lcioConvTool_4tagging.convertAll = True
lcioConvTool_4tagging.collNameMapping = {
"MCParticle": "MCParticles",
}

alg.Lcio2EDM4hepTool = lcioConvTool_4tagging



def _create_writer_lcio(writer_name: str, output_name: str, keep_list: Iterable = (), full_subset_list: Iterable = ()):
Expand Down
Loading