diff --git a/CLDConfig/CLDReconstruction.py b/CLDConfig/CLDReconstruction.py index 4683c4c..527d201 100644 --- a/CLDConfig/CLDReconstruction.py +++ b/CLDConfig/CLDReconstruction.py @@ -19,10 +19,10 @@ import os from Gaudi.Configuration import INFO, WARNING, DEBUG -from Configurables import k4DataSvc, MarlinProcessorWrapper -from k4MarlinWrapper.inputReader import create_reader, attach_edm4hep2lcio_conversion +from Configurables import MarlinProcessorWrapper, Lcio2EDM4hepTool from k4FWCore.parseArgs import parser -from py_utils import SequenceLoader, attach_lcio2edm4hep_conversion, create_writer, parse_collection_patch_file +from k4FWCore import ApplicationMgr, IOSvc +from py_utils import SequenceLoader, attach_lcio2edm4hep_conversion, attach_edm4hep2lcio_conversion, create_writer, parse_collection_patch_file, attach_lcio2edm4hep_conversion_for_tagging import ROOT ROOT.gROOT.SetBatch(True) @@ -33,6 +33,8 @@ parser_group.add_argument("--outputBasename", help="Basename of the output file(s)", default="output") parser_group.add_argument("--trackingOnly", action="store_true", help="Run only track reconstruction", default=False) parser_group.add_argument("--enableLCFIJet", action="store_true", help="Enable LCFIPlus jet clustering parts", default=False) +parser_group.add_argument("--enableMLJetTagger", action="store_true", help="Enable ML-based jet flavor tagging", default=False) +parser_group.add_argument("--MLJetTaggerModel", action="store", help="Type of ML model to use for inference", type=str, default="model_ParT_ecm240_cld_o2_v5") parser_group.add_argument("--cms", action="store", help="Choose a Centre-of-Mass energy", default=240, choices=(91, 160, 240, 365), type=int) parser_group.add_argument("--compactFile", help="Compact detector file to use", type=str, default=os.environ["K4GEO"] + "/FCCee/CLD/compact/CLD_o2_v07/CLD_o2_v07.xml") tracking_group = parser_group.add_mutually_exclusive_group() @@ -43,8 +45,14 @@ algList = [] svcList = [] -evtsvc = k4DataSvc("EventDataSvc") -svcList.append(evtsvc) +if not reco_args.inputFiles: + print('WARNING: No input files specified, the CLD Reconstruction will fail') + reco_args.inputFiles = [] + +io_svc = IOSvc("IOSvc") +io_svc.Input = reco_args.inputFiles +io_svc.Output = f"{reco_args.outputBasename}.edm4hep.root" +svcList.append(io_svc) CONFIG = { "CalorimeterIntegrationTimeWindow": "10ns", @@ -56,10 +64,9 @@ "OutputMode": "EDM4Hep", "OutputModeChoices": ["LCIO", "EDM4hep"] #, "both"] FIXME: both is not implemented yet } - REC_COLLECTION_CONTENTS_FILE = "collections_rec_level.txt" # file with the collections to be patched in when writing from LCIO to EDM4hep -from Configurables import GeoSvc, TrackingCellIDEncodingSvc, Lcio2EDM4hepTool +from Configurables import GeoSvc, TrackingCellIDEncodingSvc geoservice = GeoSvc("GeoSvc") geoservice.detectors = [reco_args.compactFile] geoservice.OutputLevel = INFO @@ -92,14 +99,6 @@ }, ) -if reco_args.inputFiles: - read = create_reader(reco_args.inputFiles, evtsvc) - read.OutputLevel = INFO - algList.append(read) -else: - print('WARNING: No input files specified, the CLD Reconstruction will fail') - read = None - MyAIDAProcessor = MarlinProcessorWrapper("MyAIDAProcessor") MyAIDAProcessor.OutputLevel = WARNING MyAIDAProcessor.ProcessorType = "AIDAProcessor" @@ -144,6 +143,24 @@ sequenceLoader.load("HighLevelReco/PFOSelector") sequenceLoader.load("HighLevelReco/JetClusteringOrRenaming") sequenceLoader.load("HighLevelReco/JetAndVertex") + +# jet-flavor tagging +if not reco_args.trackingOnly and reco_args.enableMLJetTagger: + # convert all lcio collections to edm4hep - tagger expects edm4hep collections + + # Make sure that all collections are always available by patching in missing ones on-the-fly + collPatcher_4tagging = MarlinProcessorWrapper( + "CollPatcher_4tagging", OutputLevel=INFO, ProcessorType="PatchCollections" + ) + collPatcher_4tagging.Parameters = { + "PatchCollections": parse_collection_patch_file(REC_COLLECTION_CONTENTS_FILE) + } + algList.append(collPatcher_4tagging) + # actual conversion + attach_lcio2edm4hep_conversion_for_tagging(algList) + # add the tagger + sequenceLoader.load("HighLevelReco/MLJetTagger") + # event number processor, down here to attach the conversion back to edm4hep to it algList.append(EventNumber) @@ -169,8 +186,10 @@ } algList.append(collPatcherRec) - Output_REC = create_writer("edm4hep", "Output_REC", f"{reco_args.outputBasename}_REC") - algList.append(Output_REC) + # keep all collections + io_svc.outputCommands = ["keep *"] + + # FIXME: add option to write only selected collections with SVC # FIXME: needs https://github.com/key4hep/k4FWCore/issues/226 # Output_DST = create_writer("edm4hep", "Output_DST", f"{reco_args.outputBasename}_DST", DST_KEEPLIST) @@ -178,12 +197,12 @@ # We need to convert the inputs in case we have EDM4hep input -attach_edm4hep2lcio_conversion(algList, read) +attach_edm4hep2lcio_conversion(algList) # , read) # We need to convert the outputs in case we have EDM4hep output attach_lcio2edm4hep_conversion(algList) -from Configurables import ApplicationMgr + ApplicationMgr( TopAlg = algList, EvtSel = 'NONE', EvtMax = 3, # Overridden by the --num-events switch to k4run diff --git a/CLDConfig/HighLevelReco/MLJetTagger.py b/CLDConfig/HighLevelReco/MLJetTagger.py new file mode 100644 index 0000000..096f494 --- /dev/null +++ b/CLDConfig/HighLevelReco/MLJetTagger.py @@ -0,0 +1,72 @@ +# +# Copyright (c) 2014-2024 Key4hep-Project. +# +# This file is part of Key4hep. +# See https://key4hep.github.io/key4hep-doc/ for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from Gaudi.Configuration import WARNING, INFO, DEBUG +from Configurables import JetTagger, Lcio2EDM4hepTool, MarlinProcessorWrapper +import yaml +import os + +if reco_args.enableMLJetTagger: + # check if jet clustering is also enabled (prerequisite for jet flavor tagging) + if not reco_args.enableLCFIJet: + raise ValueError("MLJetTagger requires LCFIPlus jet clustering to be enabled. Please add --enableLCFIJet to the command or disable --enableMLJetTagger.") + + # Get the directory of the current script + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Construct the path to the YAML file + yaml_path = os.path.join(script_dir, "models_MLJetTagger.yaml") + + # Load YAML config + with open(yaml_path, "r") as file: + model_config = yaml.safe_load(file) + + # check if the model type is valid + if reco_args.MLJetTaggerModel not in model_config: + raise ValueError(f"Invalid model type '{reco_args.MLJetTaggerModel}'. Valid options are: {', '.join(model_config.keys())}.") + + # load the model configuration + onnx_model = model_config[reco_args.MLJetTaggerModel]["onnx_model"] + json_onnx_config = model_config[reco_args.MLJetTaggerModel]["json_onnx_config"] + flavor_collection_names = model_config[reco_args.MLJetTaggerModel]["flavor_collection_names"] + + # print out the model configuration + print("RUNNING JET TAGGING WITH MLJETTAGGER") + + print(f"Using MLJetTagger model: \t\t {reco_args.MLJetTaggerModel}\n", + f"The model uses the architecture: \t {model_config[reco_args.MLJetTaggerModel]['model']}\n", + f"was trained on the kinematics: \t {model_config[reco_args.MLJetTaggerModel]['kinematics']}\n", + f"and the detector version: \t\t {model_config[reco_args.MLJetTaggerModel]['detector']}\n", + f"at a center-of-mass energy of: \t {model_config[reco_args.MLJetTaggerModel]['ecm']} GeV\n", + f"Comment: \t\t\t\t {model_config[reco_args.MLJetTaggerModel]['comment']}\n", + f"Appending collections to the event: \t {', '.join(flavor_collection_names)}\n",) + + # create the MLJetTagger algorithm + k4MLJetTagger = JetTagger("JetTagger", + model_path=onnx_model, + json_path=json_onnx_config, + flavor_collection_names = flavor_collection_names, # to make sure the order and nameing is correct + InputJets=["RefinedVertexJets"], + InputPrimaryVertices=["PrimaryVertices"], + OutputIDCollections=flavor_collection_names, + OutputLevel=DEBUG, + ) + + # append sequence to the algorithm list + MLJetTaggerSequence = [ + k4MLJetTagger, + ] diff --git a/CLDConfig/HighLevelReco/models_MLJetTagger.yaml b/CLDConfig/HighLevelReco/models_MLJetTagger.yaml new file mode 100644 index 0000000..1e70feb --- /dev/null +++ b/CLDConfig/HighLevelReco/models_MLJetTagger.yaml @@ -0,0 +1,18 @@ +# this yaml file stores and should be filled in with information about how a jet-flavor tagger is trained and the necessary information to run inference + +model_ParT_ecm240_cld_o2_v5: + model: "ParticleTransformer" + ecm: 240 + detector: "CLD_o2_v5" + kinematics: "Z(vv)H(jj)" + onnx_model: "/eos/experiment/fcc/ee/jet_flavour_tagging/fullsim_test_spring2024/fullsimCLD240_2mio.onnx" + json_onnx_config: "/eos/experiment/fcc/ee/jet_flavour_tagging/fullsim_test_spring2024/preprocess_fullsimCLD240_2mio.json" + flavor_collection_names: + - "RefinedJetTag_G" + - "RefinedJetTag_U" + - "RefinedJetTag_S" + - "RefinedJetTag_C" + - "RefinedJetTag_B" + - "RefinedJetTag_D" + - "RefinedJetTag_TAU" + comment: "The model was trained on 1.9 mio/jets per flavor. First implementation of ML tagging for full sim." diff --git a/CLDConfig/cdb.log b/CLDConfig/cdb.log new file mode 100644 index 0000000..e69de29 diff --git a/CLDConfig/py_utils.py b/CLDConfig/py_utils.py index 1b2143a..bd659da 100644 --- a/CLDConfig/py_utils.py +++ b/CLDConfig/py_utils.py @@ -21,9 +21,9 @@ import importlib.util import importlib.abc from importlib.machinery import SourceFileLoader -from Configurables import PodioOutput, MarlinProcessorWrapper +from Configurables import PodioOutput, MarlinProcessorWrapper, EDM4hep2LcioTool from typing import Iterable -from Gaudi.Configuration import WARNING +from Gaudi.Configuration import WARNING, DEBUG def import_from( @@ -132,7 +132,42 @@ def load(self, sequence: str) -> None: seq = getattr(seq_module, seq_name) self.alg_list.extend(seq) +def create_reader(input_files, evtSvc): + # FIXME: from https://github.com/key4hep/k4MarlinWrapper/blob/main/k4MarlinWrapper/python/k4MarlinWrapper/inputReader.py#L24-L40 but adapt it to IOSvc + """Create the appropriate reader for the input files""" + if input_files[0].endswith(".slcio"): + if any(not f.endswith(".slcio") for f in input_files): + print("All input files need to have the same format (LCIO)") + sys.exit(1) + read = LcioEvent() + read.Files = input_files + else: + if any(not f.endswith(".root") for f in input_files): + print("All input files need to have the same format (EDM4hep)") + sys.exit(1) + read = PodioInput("PodioInput") + evtSvc.inputs = input_files + + return read + + +def attach_edm4hep2lcio_conversion(algList): + """Attach the edm4hep to lcio conversion if necessary e.g. when using create_reader. Should only be run after algList is complete.""" + # if not isinstance(read, PodioInput): + # # nothing to convert :) + # return + + # find first wrapper + for alg in algList: + if isinstance(alg, MarlinProcessorWrapper): + break + + EDM4hep2LcioInput = EDM4hep2LcioTool("InputConversion") + EDM4hep2LcioInput.convertAll = True + # Adjust for the different naming conventions + EDM4hep2LcioInput.collNameMapping = {"MCParticles": "MCParticle"} + alg.EDM4hep2LcioTool = EDM4hep2LcioInput def attach_lcio2edm4hep_conversion(algList: list) -> None: """Attaches a conversion from lcio to edm4hep at the last MarlinWrapper in algList if necessary @@ -155,6 +190,23 @@ def attach_lcio2edm4hep_conversion(algList: list) -> None: alg.Lcio2EDM4hepTool = lcioConvTool +def attach_lcio2edm4hep_conversion_for_tagging(algList: list) -> None: + """Attaches a conversion from lcio to edm4hep at the last MarlinWrapper in algList just before tagging, as the tagger expect edm4hep collections + """ + # find last marlin wrapper + for alg in reversed(algList): + if isinstance(alg, MarlinProcessorWrapper): + break + + from Configurables import Lcio2EDM4hepTool + lcioConvTool_4tagging = Lcio2EDM4hepTool("lcio2EDM4hep") + lcioConvTool_4tagging.convertAll = True + lcioConvTool_4tagging.collNameMapping = { + "MCParticle": "MCParticles", + } + + alg.Lcio2EDM4hepTool = lcioConvTool_4tagging + def _create_writer_lcio(writer_name: str, output_name: str, keep_list: Iterable = (), full_subset_list: Iterable = ()):