Skip to content

Commit a3d90d6

Browse files
committed
make grouping aware of the origin change; fix exception in treeName generator; restrict Preslice to tables
1 parent 90722db commit a3d90d6

9 files changed

Lines changed: 67 additions & 41 deletions

File tree

Framework/Core/include/Framework/ASoA.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,7 +1547,7 @@ struct PreslicePolicyGeneral : public PreslicePolicyBase {
15471547
template <typename T>
15481548
concept is_preslice_policy = std::derived_from<T, PreslicePolicyBase>;
15491549

1550-
template <typename T, is_preslice_policy Policy, bool OPT = false>
1550+
template <soa::is_table T, is_preslice_policy Policy, bool OPT = false>
15511551
struct PresliceBase : public Policy {
15521552
constexpr static bool optional = OPT;
15531553
using target_t = T;
@@ -1580,13 +1580,13 @@ struct PresliceBase : public Policy {
15801580
}
15811581
};
15821582

1583-
template <typename T>
1583+
template <soa::is_table T>
15841584
using PresliceUnsorted = PresliceBase<T, PreslicePolicyGeneral, false>;
1585-
template <typename T>
1585+
template <soa::is_table T>
15861586
using PresliceUnsortedOptional = PresliceBase<T, PreslicePolicyGeneral, true>;
1587-
template <typename T>
1587+
template <soa::is_table T>
15881588
using Preslice = PresliceBase<T, PreslicePolicySorted, false>;
1589-
template <typename T>
1589+
template <soa::is_table T>
15901590
using PresliceOptional = PresliceBase<T, PreslicePolicySorted, true>;
15911591

15921592
template <typename T>
@@ -1744,7 +1744,12 @@ auto doFilteredSliceBy(T const* table, o2::framework::PresliceBase<C, framework:
17441744
template <soa::is_table T>
17451745
auto doSliceByCached(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache)
17461746
{
1747-
auto localCache = cache.ptr->getCacheFor({"", o2::soa::getMatcherFromTypeForKey<T>(node.name), node.name});
1747+
auto localCache = cache.ptr->getCacheFor({"", [&o = cache.ptr->newOrigin](framework::ConcreteDataMatcher&& m){
1748+
if ((m.origin == header::DataOrigin{"AOD"}) && (o != header::DataOrigin{"AOD"})) {
1749+
m.origin = o;
1750+
}
1751+
return m;
1752+
}(o2::soa::getMatcherFromTypeForKey<T>(node.name)), node.name});
17481753
auto [offset, count] = localCache.getSliceFor(value);
17491754
auto t = typename T::self_t({table->asArrowTable()->Slice(static_cast<uint64_t>(offset), count)}, static_cast<uint64_t>(offset));
17501755
if (t.tableSize() != 0) {
@@ -1756,7 +1761,12 @@ auto doSliceByCached(T const* table, framework::expressions::BindingNode const&
17561761
template <soa::is_filtered_table T>
17571762
auto doFilteredSliceByCached(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache)
17581763
{
1759-
auto localCache = cache.ptr->getCacheFor({"", o2::soa::getMatcherFromTypeForKey<T>(node.name), node.name});
1764+
auto localCache = cache.ptr->getCacheFor({"", [&o = cache.ptr->newOrigin](framework::ConcreteDataMatcher&& m){
1765+
if ((m.origin == header::DataOrigin{"AOD"}) && (o != header::DataOrigin{"AOD"})) {
1766+
m.origin = o;
1767+
}
1768+
return m;
1769+
}(o2::soa::getMatcherFromTypeForKey<T>(node.name)), node.name});
17601770
auto [offset, count] = localCache.getSliceFor(value);
17611771
auto slice = table->asArrowTable()->Slice(static_cast<uint64_t>(offset), count);
17621772
return prepareFilteredSlice(table, slice, offset);
@@ -1765,7 +1775,12 @@ auto doFilteredSliceByCached(T const* table, framework::expressions::BindingNode
17651775
template <soa::is_table T>
17661776
auto doSliceByCachedUnsorted(T const* table, framework::expressions::BindingNode const& node, int value, o2::framework::SliceCache& cache)
17671777
{
1768-
auto localCache = cache.ptr->getCacheUnsortedFor({"", o2::soa::getMatcherFromTypeForKey<T>(node.name), node.name});
1778+
auto localCache = cache.ptr->getCacheUnsortedFor({"", [&o = cache.ptr->newOrigin](framework::ConcreteDataMatcher&& m){
1779+
if ((m.origin == header::DataOrigin{"AOD"}) && (o != header::DataOrigin{"AOD"})) {
1780+
m.origin = o;
1781+
}
1782+
return m;
1783+
}(o2::soa::getMatcherFromTypeForKey<T>(node.name)), node.name});
17691784
if constexpr (soa::is_filtered_table<T>) {
17701785
auto t = typename T::self_t({table->asArrowTable()}, localCache.getSliceFor(value));
17711786
if (t.tableSize() != 0) {

Framework/Core/include/Framework/AnalysisHelpers.h

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,9 @@ constexpr auto tableRef2OutputSpec(header::DataOrigin newOrigin = header::DataOr
452452
} else if constexpr (soa::with_index_pack<md>) {
453453
metadata.emplace_back("index-records", framework::VariantType::Bool, true, framework::ConfigParamSpec::HelpString{"\"\""});
454454
}
455+
if ((R.origin_hash == "AOD"_h) && (newOrigin != header::DataOrigin{"AOD"})) {
456+
metadata.push_back(framework::ConfigParamSpec{"aod-origin-replaced", framework::VariantType::Bool, true, {"\"\""}});
457+
}
455458
return framework::OutputSpec{
456459
framework::OutputLabel{o2::aod::label<R>()},
457460
((R.origin_hash == "AOD"_h) && (newOrigin != header::DataOrigin{"AOD"})) ? newOrigin : o2::aod::origin<R>(),
@@ -461,15 +464,6 @@ constexpr auto tableRef2OutputSpec(header::DataOrigin newOrigin = header::DataOr
461464
metadata};
462465
}
463466

464-
template <TableRef R>
465-
constexpr auto tableRef2Output()
466-
{
467-
return framework::Output{
468-
o2::aod::origin<R>(),
469-
o2::aod::description(o2::aod::signature<R>()),
470-
R.version};
471-
}
472-
473467
template <TableRef R>
474468
constexpr auto tableRef2OutputRef()
475469
{
@@ -498,9 +492,9 @@ struct WritingCursor {
498492
using persistent_table_t = decltype([]() { if constexpr (soa::is_iterator<T>) { return typename T::parent_t{nullptr}; } else { return T{nullptr}; } }());
499493
using cursor_t = decltype(std::declval<TableBuilder>().cursor<persistent_table_t>());
500494
OutputSpec outputSpec{soa::tableRef2OutputSpec<persistent_table_t::ref>()};
501-
OutputSpec updateOutputSpec(header::DataOrigin const& newOrigin)
495+
OutputSpec updateOutputSpec(header::DataOrigin const& newOrigin = header::DataOrigin{"AOD"})
502496
{
503-
outputSpec = soa::tableRef2OutputSpec<persistent_table_t::ref>(newOrigin);
497+
return soa::tableRef2OutputSpec<persistent_table_t::ref>(newOrigin);
504498
}
505499

506500
template <typename... Ts>

Framework/Core/include/Framework/AnalysisManagers.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ bool updateOutputSpec(T& entity, header::DataOrigin newOrigin = header::DataOrig
190190
return true;
191191
}
192192

193+
template <is_produces_group T>
194+
bool updateOutputSpec(T& producesGroup, header::DataOrigin newOrigin = header::DataOrigin{"AOD"})
195+
{
196+
homogeneous_apply_refs<true>([&newOrigin](auto& produces){ return updateOutputSpec(produces, newOrigin); }, producesGroup);
197+
}
198+
193199
template <typename C>
194200
bool newDataframeCondition(InputRecord&, C&)
195201
{
@@ -604,9 +610,9 @@ bool replaceOrigin(T&, header::DataOrigin const&)
604610
}
605611

606612
template <is_preslice T>
607-
bool replaceOrigin(T& preslice, header::DataOrigin const& newOrigin)
613+
bool replaceOrigin(T& preslice, header::DataOrigin const& newOrigin = header::DataOrigin{"AOD"})
608614
{
609-
if ((T::target_t::originals[0].origin_hash == "AOD"_h) && (newOrigin != header::DataOrigin{"AOD"})) {
615+
if ((T::target_t::binding_origin == "AOD"_h) && (newOrigin != header::DataOrigin{"AOD"})) {
610616
preslice.bindingKey.matcher = framework::replaceOrigin(preslice.bindingKey.matcher, newOrigin);
611617
return true;
612618
}

Framework/Core/include/Framework/AnalysisTask.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ struct AnalysisDataProcessorBuilder {
315315
}
316316

317317
template <typename Task, is_table_iterator_or_enumeration Grouping, soa::is_table... Associated>
318-
static void invokeProcess(Task& task, InputRecord& inputs, std::vector<InputInfo> iInfos, void (Task::*processingFunction)(Grouping, Associated...), std::vector<ExpressionInfo>& infos, ArrowTableSlicingCache& slices, std::string const& newOriginStr)
318+
static void invokeProcess(Task& task, InputRecord& inputs, std::vector<InputInfo> iInfos, void (Task::*processingFunction)(Grouping, Associated...), std::vector<ExpressionInfo>& infos, ArrowTableSlicingCache& slices, header::DataOrigin newOrigin = header::DataOrigin{"AOD"})
319319
{
320320
using G = std::decay_t<Grouping>;
321321
auto groupingTable = AnalysisDataProcessorBuilder::bindGroupingTable(inputs, iInfos, processingFunction, infos);
@@ -387,7 +387,7 @@ struct AnalysisDataProcessorBuilder {
387387
task);
388388
overwriteInternalIndices(associatedTables, associatedTables);
389389
if constexpr (soa::is_iterator<G>) {
390-
auto slicer = GroupSlicer(groupingTable, associatedTables, slices, newOriginStr);
390+
auto slicer = GroupSlicer(groupingTable, associatedTables, slices, newOrigin);
391391
for (auto& slice : slicer) {
392392
auto associatedSlices = slice.associatedTables();
393393
overwriteInternalIndices(associatedSlices, associatedTables);
@@ -652,8 +652,9 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
652652

653653
ic.services().get<ArrowTableSlicingCacheDef>().setCaches(std::move(bindingsKeys));
654654
ic.services().get<ArrowTableSlicingCacheDef>().setCachesUnsorted(std::move(bindingsKeysUnsorted));
655+
ic.services().get<ArrowTableSlicingCacheDef>().setOrigin(newOrigin);
655656

656-
return [task, expressionInfos, inputInfos, newOriginStr](ProcessingContext& pc) mutable {
657+
return [task, expressionInfos, inputInfos, newOrigin](ProcessingContext& pc) mutable {
657658
// load the ccdb object from their cache
658659
homogeneous_apply_refs_sized<numElements>([&pc](auto& element) { return analysis_task_parsers::newDataframeCondition(pc.inputs(), element); }, *task.get());
659660
// reset partitions once per dataframe
@@ -676,14 +677,14 @@ DataProcessorSpec adaptAnalysisTask(ConfigContext const& ctx, Args&&... args)
676677
}
677678
// execute process()
678679
if constexpr (requires { &T::process; }) {
679-
AnalysisDataProcessorBuilder::invokeProcess(*(task.get()), pc.inputs(), inputInfos, &T::process, expressionInfos, slices, newOriginStr);
680+
AnalysisDataProcessorBuilder::invokeProcess(*(task.get()), pc.inputs(), inputInfos, &T::process, expressionInfos, slices, newOrigin);
680681
}
681682
// execute optional process()
682683
homogeneous_apply_refs_sized<numElements>(
683-
[&pc, &expressionInfos, &task, &slices, &inputInfos, &newOriginStr](auto& x) {
684+
[&pc, &expressionInfos, &task, &slices, &inputInfos, &newOrigin](auto& x) {
684685
if constexpr (is_process_configurable<decltype(x)>) {
685686
if (x.value == true) {
686-
AnalysisDataProcessorBuilder::invokeProcess(*task.get(), pc.inputs(), inputInfos, x.process, expressionInfos, slices, newOriginStr);
687+
AnalysisDataProcessorBuilder::invokeProcess(*task.get(), pc.inputs(), inputInfos, x.process, expressionInfos, slices, newOrigin);
687688
return true;
688689
}
689690
return false;

Framework/Core/include/Framework/ArrowTableSlicingCache.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,14 @@ struct ArrowTableSlicingCacheDef {
6464
constexpr static ServiceKind service_kind = ServiceKind::Global;
6565
Cache bindingsKeys;
6666
Cache bindingsKeysUnsorted;
67+
header::DataOrigin newOrigin = header::DataOrigin{"AOD"};
6768

6869
void setCaches(Cache&& bsks);
6970
void setCachesUnsorted(Cache&& bsks);
71+
void setOrigin(header::DataOrigin newOrigin_ = header::DataOrigin{"AOD"})
72+
{
73+
newOrigin = newOrigin_;
74+
}
7075
};
7176

7277
struct ArrowTableSlicingCache {
@@ -80,7 +85,9 @@ struct ArrowTableSlicingCache {
8085
std::vector<std::vector<int>> valuesUnsorted;
8186
std::vector<ListVector> groups;
8287

83-
ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted = {});
88+
header::DataOrigin newOrigin = header::DataOrigin{"AOD"};
89+
90+
ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted = {}, header::DataOrigin newOrigin_ = header::DataOrigin{"AOD"});
8491

8592
// set caching information externally
8693
void setCaches(Cache&& bsks, Cache&& bsksUnsorted = {});

Framework/Core/include/Framework/GroupSlicer.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ namespace o2::framework
2626
template <typename G, typename... A>
2727
struct GroupSlicer {
2828
using grouping_t = std::decay_t<G>;
29-
GroupSlicer(G& gt, std::tuple<A...>& at, ArrowTableSlicingCache& slices, std::string const& newOriginStr = "AOD")
29+
GroupSlicer(G& gt, std::tuple<A...>& at, ArrowTableSlicingCache& slices, header::DataOrigin newOrigin = header::DataOrigin{"AOD"})
3030
: max{gt.size()},
31-
mBegin{GroupSlicerIterator(gt, at, slices, newOriginStr)}
31+
mBegin{GroupSlicerIterator(gt, at, slices, newOrigin)}
3232
{
3333
}
3434

@@ -87,15 +87,15 @@ struct GroupSlicer {
8787
starts[index] = selections[index]->begin();
8888
}
8989

90-
GroupSlicerIterator(G& gt, std::tuple<A...>& at, ArrowTableSlicingCache& slices, std::string const& newOriginStr = "AOD")
90+
GroupSlicerIterator(G& gt, std::tuple<A...>& at, ArrowTableSlicingCache& slices, header::DataOrigin newOrigin = header::DataOrigin{"AOD"})
9191
: mIndexColumnName{std::string("fIndex") + o2::framework::cutString(o2::soa::getLabelFromType<G>())},
9292
mGt{&gt},
9393
mAt{&at},
9494
mGroupingElement{gt.begin()},
9595
position{0},
96-
mSlices{&slices}
96+
mSlices{&slices},
97+
replacementOrigin{newOrigin}
9798
{
98-
replacementOrigin.runtimeInit(newOriginStr.c_str(), newOriginStr.size());
9999
if constexpr (soa::is_filtered_table<std::decay_t<G>>) {
100100
groupSelection = mGt->getSelectedRows();
101101
}

Framework/Core/src/AnalysisDataModelHelpers.cxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ std::string getTreeName(header::DataHeader dh, bool wasAOD)
4242

4343
// exceptions from this
4444
auto origin = std::string(dh.dataOrigin.str);
45-
if (origin == "AOD" && description == "MCCOLLISLABEL") {
45+
if ((origin == "AOD" || wasAOD) && description == "MCCOLLISLABEL") {
4646
treeName = "O2mccollisionlabel";
4747
}
4848

Framework/Core/src/ArrowSupport.cxx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,8 @@ o2::framework::ServiceSpec ArrowSupport::arrowTableSlicingCacheSpec()
769769
.uniqueId = CommonServices::simpleServiceId<ArrowTableSlicingCache>(),
770770
.init = [](ServiceRegistryRef services, DeviceState&, fair::mq::ProgOptions&) { return ServiceHandle{TypeIdHelpers::uniqueId<ArrowTableSlicingCache>(),
771771
new ArrowTableSlicingCache(Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeys},
772-
Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted}),
772+
Cache{services.get<ArrowTableSlicingCacheDef>().bindingsKeysUnsorted},
773+
services.get<ArrowTableSlicingCacheDef>().newOrigin),
773774
ServiceKind::Stream, typeid(ArrowTableSlicingCache).name()}; },
774775
.configure = CommonServices::noConfiguration(),
775776
.preProcessing = [](ProcessingContext& pc, void* service_ptr) {

Framework/Core/src/ArrowTableSlicingCache.cxx

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "Framework/ArrowTableSlicingCache.h"
1313
#include "Framework/RuntimeError.h"
14+
#include "Framework/DataSpecUtils.h"
1415

1516
#include <arrow/compute/api_aggregate.h>
1617
#include <arrow/compute/kernel.h>
@@ -78,9 +79,10 @@ void ArrowTableSlicingCacheDef::setCachesUnsorted(Cache&& bsks)
7879
bindingsKeysUnsorted = bsks;
7980
}
8081

81-
ArrowTableSlicingCache::ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted)
82+
ArrowTableSlicingCache::ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorted, header::DataOrigin newOrigin_)
8283
: bindingsKeys{bsks},
83-
bindingsKeysUnsorted{bsksUnsorted}
84+
bindingsKeysUnsorted{bsksUnsorted},
85+
newOrigin{newOrigin_}
8486
{
8587
offsets.resize(bindingsKeys.size());
8688
sizes.resize(bindingsKeys.size());
@@ -112,7 +114,7 @@ arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr<
112114
}
113115
auto& [b, m, k, e] = bindingsKeys[pos];
114116
if (!e) {
115-
throw runtime_error_f("Disabled cache %s/%s update requested", b.c_str(), k.c_str());
117+
throw runtime_error_f("Disabled cache (%s) %s/%s update requested", DataSpecUtils::describe(m).c_str(), b.c_str(), k.c_str());
116118
}
117119
validateOrder(bindingsKeys[pos], table);
118120

@@ -205,7 +207,7 @@ std::pair<int, bool> ArrowTableSlicingCache::getCachePos(const Entry& bindingKey
205207
if (pos != -1) {
206208
return {pos, false};
207209
}
208-
throw runtime_error_f("%s/%s not found neither in sorted or unsorted cache", bindingKey.binding.c_str(), bindingKey.key.c_str());
210+
throw runtime_error_f("(%s) %s/%s not found neither in sorted or unsorted cache", DataSpecUtils::describe(bindingKey.matcher).c_str(), bindingKey.binding.c_str(), bindingKey.key.c_str());
209211
}
210212

211213
int ArrowTableSlicingCache::getCachePosSortedFor(Entry const& bindingKey) const
@@ -242,10 +244,10 @@ SliceInfoUnsortedPtr ArrowTableSlicingCache::getCacheUnsortedFor(const Entry& bi
242244
{
243245
auto [p, s] = getCachePos(bindingKey);
244246
if (s) {
245-
throw runtime_error_f("%s/%s is found in sorted cache", bindingKey.binding.c_str(), bindingKey.key.c_str());
247+
throw runtime_error_f("(%s) %s/%s is found in sorted cache", DataSpecUtils::describe(bindingKey.matcher).c_str(), bindingKey.binding.c_str(), bindingKey.key.c_str());
246248
}
247249
if (!bindingsKeysUnsorted[p].enabled) {
248-
throw runtime_error_f("Disabled unsorted cache %s/%s is requested", bindingKey.binding.c_str(), bindingKey.key.c_str());
250+
throw runtime_error_f("Disabled unsorted cache (%s) %s/%s is requested", DataSpecUtils::describe(bindingKey.matcher).c_str(), bindingKey.binding.c_str(), bindingKey.key.c_str());
249251
}
250252

251253
return getCacheUnsortedForPos(p);

0 commit comments

Comments
 (0)