Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Libs/Optimize/Optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ int Optimize::SetParameters() {
this->ReadPrefixTransformFile(m_prefix_transform_file);
}

// Apply stored Procrustes transforms (e.g. for fixed shapes loaded from project)
for (auto& [domain_index, transform] : m_procrustes_transforms) {
m_sampler->GetParticleSystem()->SetTransform(domain_index, transform);
}

return true;
}

Expand Down Expand Up @@ -1811,6 +1816,11 @@ void Optimize::SetFixedDomains(std::vector<int> flags) {
this->m_domain_flags = flags;
}

//---------------------------------------------------------------------------
void Optimize::SetProcustesTransforms(std::map<int, vnl_matrix_fixed<double, 4, 4>> transforms) {
m_procrustes_transforms = std::move(transforms);
}

//---------------------------------------------------------------------------
const std::vector<int>& Optimize::GetDomainFlags() { return this->m_domain_flags; }

Expand Down
5 changes: 5 additions & 0 deletions Libs/Optimize/Optimize.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#endif

// std
#include <map>
#include <string>
#include <vector>

Expand Down Expand Up @@ -236,6 +237,9 @@ class Optimize {
//! Set Domain Flags (TODO: details)
void SetFixedDomains(std::vector<int> flags);

//! Set Procrustes transforms to load for specific domains (applied after initialization)
void SetProcustesTransforms(std::map<int, vnl_matrix_fixed<double, 4, 4>> transforms);

//! Shared boundary settings
void SetSharedBoundaryEnabled(bool enabled);
void SetSharedBoundaryWeight(double weight);
Expand Down Expand Up @@ -415,6 +419,7 @@ class Optimize {
double m_cotan_sigma_factor = 5.0;
std::vector<int> m_particle_flags;
std::vector<int> m_domain_flags;
std::map<int, vnl_matrix_fixed<double, 4, 4>> m_procrustes_transforms; // domain index -> transform
double m_narrow_band = 0.0;
bool m_narrow_band_set = false;
bool m_fixed_domains_present = false;
Expand Down
28 changes: 27 additions & 1 deletion Libs/Optimize/OptimizeParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,30 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
SW_DEBUG("Setting Initial Points");
optimize->SetInitialPoints(get_initial_points());
}

// Store Procrustes transforms for fixed shapes (applied after ParticleSystem initialization)
using TransformType = vnl_matrix_fixed<double, 4, 4>;
std::map<int, TransformType> procrustes_transforms;
int domain_idx = 0;
for (auto s : subjects) {
auto pt = s->get_procrustes_transforms();
for (int d = 0; d < domains_per_shape; d++) {
if (s->is_fixed() && d < pt.size() && pt[d].size() == 16) {
TransformType transform;
int index = 0;
for (int c = 0; c < 4; c++) {
for (int r = 0; r < 4; r++) {
transform[c][r] = pt[d][index++];
}
}
procrustes_transforms[domain_idx] = transform;
}
domain_idx++;
}
}
if (!procrustes_transforms.empty()) {
optimize->SetProcustesTransforms(std::move(procrustes_transforms));
}
}

for (auto s : subjects) {
Expand Down Expand Up @@ -765,7 +789,9 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
}
}

optimize->GetSampler()->GetParticleSystem()->SetPrefixTransform(domain_count++, prefix_transform);
optimize->GetSampler()->GetParticleSystem()->SetPrefixTransform(domain_count, prefix_transform);

domain_count++;

auto name = StringUtils::getBaseFilenameWithoutExtension(filename);

Expand Down
33 changes: 11 additions & 22 deletions Libs/Optimize/ProcrustesRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
namespace shapeworks {

//---------------------------------------------------------------------------
Procrustes3D::ShapeType ProcrustesRegistration::ExtractShape(int domain_index, int num_points) {
Procrustes3D::ShapeType ProcrustesRegistration::ExtractShape(int domain_index, int num_points, bool fully_transformed) {
Procrustes3D::ShapeType shape;
Procrustes3D::PointType point;
for (int j = 0; j < num_points; j++) {
point(0) = m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index)[0];
point(1) = m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index)[1];
point(2) = m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index)[2];
auto pos = fully_transformed ? m_ParticleSystem->GetTransformedPosition(j, domain_index)
: m_ParticleSystem->GetPrefixTransformedPosition(j, domain_index);
point(0) = pos[0];
point(1) = pos[1];
point(2) = pos[2];
shape.push_back(point);
}
return shape;
Expand Down Expand Up @@ -53,36 +55,23 @@ void ProcrustesRegistration::RunFixedDomainRegistration(int domainStart, int num

// Build/rebuild cache if needed (first call or particle count changed after split)
if (!cache.valid || cache.num_points != numPoints) {
// Extract fixed shapes using their full transforms (prefix + existing Procrustes).
// Fixed shapes already have correct transforms; we never modify them.
Procrustes3D::ShapeListType fixed_shapelist;
std::vector<int> fixed_domain_indices;

for (int i = 0, k = domainStart; i < numShapes; i++, k += m_DomainsPerShape) {
if (!is_fixed[i]) continue;
fixed_shapelist.push_back(ExtractShape(k, numPoints));
fixed_domain_indices.push_back(k);
fixed_shapelist.push_back(ExtractShape(k, numPoints, /*fully_transformed=*/true));
}

// Run GPA on fixed shapes only
Procrustes3D::SimilarityTransformListType fixed_transforms;
// Compute mean of the already-aligned fixed shapes (no GPA needed)
Procrustes3D procrustes(m_Scaling, m_RotationTranslation);
procrustes.AlignShapes(fixed_transforms, fixed_shapelist);

// Set transforms for fixed shapes
Procrustes3D::TransformMatrixListType fixed_matrices;
procrustes.ConstructTransformMatrices(fixed_transforms, fixed_matrices);

for (size_t i = 0; i < fixed_domain_indices.size(); i++) {
m_ParticleSystem->SetTransform(fixed_domain_indices[i], fixed_matrices[i]);
}

// Compute and cache the mean of the aligned fixed shapes
// (fixed_shapelist has been modified in-place by AlignShapes to be in Procrustes space)
procrustes.ComputeMeanShape(cache.mean, fixed_shapelist);
cache.num_points = numPoints;
cache.valid = true;

SW_LOG("Procrustes: cached fixed domain mean for domain type {} ({} fixed shapes, {} points)", domainType,
fixed_domain_indices.size(), numPoints);
fixed_shapelist.size(), numPoints);
}

// Align each non-fixed shape to the cached fixed mean using OPA
Expand Down
6 changes: 4 additions & 2 deletions Libs/Optimize/ProcrustesRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ class ProcrustesRegistration {
void RunFixedDomainRegistration(int domainStart, int numShapes, int numPoints,
const std::vector<bool>& is_fixed);

//! Extract prefix-transformed particle positions for a single domain
Procrustes3D::ShapeType ExtractShape(int domain_index, int num_points);
//! Extract particle positions for a single domain.
//! If fully_transformed is true, applies both prefix and Procrustes transforms (world space).
//! If false, applies only the prefix transform (for computing new Procrustes transforms).
Procrustes3D::ShapeType ExtractShape(int domain_index, int num_points, bool fully_transformed = false);

int m_DomainsPerShape = 1;
bool m_Scaling = true;
Expand Down
91 changes: 86 additions & 5 deletions Testing/OptimizeTests/OptimizeTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,93 @@ TEST(OptimizeTests, fixed_domain_procrustes) {
std::cerr << "Eigenvalue " << i << " : " << values[i] << "\n";
}

// With Procrustes scaling enabled, the size variation between spheres should be
// factored out, resulting in a much smaller top eigenvalue compared to the
// fixed_domain test (which has Procrustes disabled and gets >5000).
// The top eigenvalue should be small since all shapes are spheres differing only in scale.
// Fixed shapes keep their existing Procrustes transforms (identity in this test).
// Only the new shape (sphere40) gets a Procrustes transform computed via OPA against
// the fixed mean. Since the test fixed shapes have identity transforms with different
// scales, the eigenvalue will be large (scale variation is not normalized).
// In a real pipeline, fixed shapes would have proper Procrustes transforms from their
// original optimization. Here we just verify the optimization completes successfully.
double value = values[values.size() - 1];
ASSERT_LT(value, 100.0);
ASSERT_GT(value, 0.0);
}

//---------------------------------------------------------------------------
// Test that multiple new (non-fixed) shapes don't interact with each other.
// Running two new shapes together with fixed shapes should produce the same
// result as running each new shape individually with the same fixed shapes.
TEST(OptimizeTests, fixed_domain_independence) {
// Helper lambda: run optimization with specified fixed/excluded/new configuration
// Returns local particles for each domain, indexed by domain index in the project
auto run_optimize = [](const std::string& temp_name,
const std::vector<bool>& is_fixed,
const std::vector<bool>& is_excluded) -> std::vector<std::vector<itk::Point<double>>> {
prep_temp("/optimize/fixed_domain", temp_name);

Optimize app;
ProjectHandle project = std::make_shared<Project>();
EXPECT_TRUE(project->load("optimize.swproj"));

// Reconfigure which subjects are fixed/excluded
auto subjects = project->get_subjects();
for (int i = 0; i < subjects.size(); i++) {
subjects[i]->set_fixed(is_fixed[i]);
subjects[i]->set_excluded(is_excluded[i]);
}

OptimizeParameters params(project);
EXPECT_TRUE(params.set_up_optimize(&app));
bool success = app.Run();
EXPECT_TRUE(success);

return app.GetLocalPoints();
};

// Project has 4 shapes: sphere10, sphere20, sphere30, sphere40
// Run A: sphere10,20 fixed; sphere30,40 both new
auto points_together = run_optimize(
"fixed_domain_indep_together",
{true, true, false, false}, // is_fixed
{false, false, false, false} // is_excluded
);

// Run B: sphere10,20 fixed; sphere30 new; sphere40 excluded
auto points_30_alone = run_optimize(
"fixed_domain_indep_30",
{true, true, false, false}, // is_fixed
{false, false, false, true} // is_excluded: sphere40 excluded
);

// Run C: sphere10,20 fixed; sphere40 new; sphere30 excluded
auto points_40_alone = run_optimize(
"fixed_domain_indep_40",
{true, true, false, false}, // is_fixed
{false, false, true, false} // is_excluded: sphere30 excluded
);

// In run A (together), domains are: 0=sphere10, 1=sphere20, 2=sphere30, 3=sphere40
// In run B (30 alone), domains are: 0=sphere10, 1=sphere20, 2=sphere30
// In run C (40 alone), domains are: 0=sphere10, 1=sphere20, 2=sphere40

// Compare sphere30 particles: run A domain 2 vs run B domain 2
ASSERT_EQ(points_together[2].size(), points_30_alone[2].size());
for (int i = 0; i < points_together[2].size(); i++) {
for (int d = 0; d < 3; d++) {
EXPECT_NEAR(points_together[2][i][d], points_30_alone[2][i][d], 1e-6)
<< "sphere30 particle " << i << " dim " << d << " differs";
}
}

// Compare sphere40 particles: run A domain 3 vs run C domain 2
ASSERT_EQ(points_together[3].size(), points_40_alone[2].size());
for (int i = 0; i < points_together[3].size(); i++) {
for (int d = 0; d < 3; d++) {
EXPECT_NEAR(points_together[3][i][d], points_40_alone[2][i][d], 1e-6)
<< "sphere40 particle " << i << " dim " << d << " differs";
}
}

std::cerr << "Fixed domain independence test passed: new shapes produce identical "
<< "results whether run together or individually\n";
}

//---------------------------------------------------------------------------
Expand Down
Loading