[BUG] WaveDiff TFPhysicalPolychromaticField fails for arbitrary FOV positions due to exact-match lookup in find_position_indices #220

@jeipollack

Describe the Bug
Inference fails when passing valid FOV positions that are not exactly present in obs_pos.
The failure occurs in:

find_position_indices(obs_pos, batch_positions)
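The implementation of find_position_indices is not reproduced in this report. As a rough illustration of the exact-match behaviour suggested by the assertion message quoted further down, here is a minimal NumPy sketch. The function name exact_match_indices and the obs_pos values are hypothetical, and the real TensorFlow code may differ in detail.

```python
import numpy as np


def exact_match_indices(obs_pos, batch_positions):
    """Hypothetical NumPy stand-in for an exact-match position lookup.

    Returns, for each row of batch_positions, the index of the identical
    row in obs_pos. Raises ValueError if any row has no exact match.
    """
    obs_pos = np.asarray(obs_pos, dtype=np.float32)
    batch_positions = np.asarray(batch_positions, dtype=np.float32)

    # Compare every batch position with every observed position: (B, N) bool.
    matches = np.all(batch_positions[:, None, :] == obs_pos[None, :, :], axis=-1)

    # Mirrors the reported assertion: every position must be found.
    if not np.all(matches.any(axis=1)):
        raise ValueError("Some positions not found in obs_pos")

    return matches.argmax(axis=1)


# Illustrative values only; not taken from the real dataset.
obs_pos = [[-152.0, -160.0], [-153.0, -161.0]]

# A position that is present in obs_pos is found.
on_grid = exact_match_indices(obs_pos, [[-153.0, -161.0]])

# A position between the observed ones has no exact match and is rejected.
try:
    exact_match_indices(obs_pos, [[-152.5, -160.596]])
    off_grid_failed = False
except ValueError:
    off_grid_failed = True
```

Under these assumptions, any query that is not bit-for-bit equal to a stored position is rejected, however close it lies to an observed one.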

Steps to Reproduce

inference_conf_dataset.yaml:
inference:
  # Inference batch size
  batch_size: 16
  
  # Cycle to use for inference. Can be: 1, 2, ...
  cycle: 2

  # Dataset schema used to interpret input fields and control downstream conversion.
  # This does NOT modify dataset content; it only controls how fields are interpreted during conversion.
  # This affects required fields and available conversion handlers.
  #
  # Choices:
  # - INFERENCE: standard inference mode (requires positions and seds)
  # - EVALUATION: evaluation mode (may include additional fields such as sources or masks)
  schema_mode: INFERENCE
  
  # Paths to the configuration files and trained model directory
  configs:
    # Path to the directory containing the trained model
    trained_model_path: /sps/euclid/Users/jpollack/DR1/ppo_a0c6db68/workdir/

    # Subdirectory name of the trained model, e.g. psf_model
    model_subdir: checkpoint
  
    # Relative path to the training configuration file used to train the model
    trained_model_config_path: config/training_config.yaml

    # Path to the data config file (this could contain prior information)
    data_config_path: config/data_config.yaml

  # The following parameters will overwrite the `model_params` in the training config file.
  model_params:
    # Num of wavelength bins to reconstruct polychromatic objects.
    n_bins_lambda: 20

    # Downsampling rate to match the oversampled model to the specified telescope's sampling.
    output_Q: 3

    # Dimension of the pixel PSF postage stamp
    output_dim: 32

    # Flag to perform centroid error correction
    correct_centroids: True

    # Flag to perform CCD misalignment error correction
    add_ccd_misalignments: True

    # Path to ccd_misalignments file
    ccd_misalignments_aux_path: /path/to/ccd_misalignments/tiles.npy

Code:

# Set path to inference configuration file
inference_config_path = "/path/to/yaml/inference_conf_dataset.yaml"

# Create PSFInference instance.
# Note: seds, data, and masks are assumed to be defined earlier in the notebook.
psf_inferred = PSFInference(
    inference_config_path=inference_config_path,
    x_field=[-152.5],
    y_field=[-160.596],
    seds=seds,
    sources=data["images"],
    masks=masks,
)

psf_inferred.prepare_configs()

Result: TensorFlow assertion fails with error:

assertion failed: [Some positions not found in obs_pos] [Condition x == y did not hold element-wise:] [x (All_1:0) = ] [0] [y (assert_equal_1/y:0) = ] [1]
	 [[{{node assert_equal_1/Assert/AssertGuard/Assert}}]] [Op:__inference_find_position_indices_402192]

Call arguments received by layer 'tf_physical_polychromatic_field_2' (type TFPhysicalPolychromaticField):
  • inputs=['tf.Tensor(shape=(1, 2), dtype=float32)', 'tf.Tensor(shape=(1, 20, 3), dtype=float32)']
  • training=False

A nearest-neighbour check confirms that the positions are within the valid FOV range but do not exactly match any obs_pos entry.
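For reference, a check of this kind can be sketched as follows. This is not the exact code used for the report: the helper nearest_neighbour_report and the obs_pos values are hypothetical, and "valid FOV range" is taken here to mean the bounding box of obs_pos, which is an assumption.

```python
import numpy as np


def nearest_neighbour_report(obs_pos, query):
    """Return (index, distance, in_range, exact) for one query position."""
    obs_pos = np.asarray(obs_pos, dtype=np.float64)
    query = np.asarray(query, dtype=np.float64)

    # Euclidean distance from the query to every observed position.
    dists = np.linalg.norm(obs_pos - query, axis=1)
    idx = int(np.argmin(dists))

    # Assumption: "in range" means inside the bounding box of obs_pos.
    lower = obs_pos.min(axis=0)
    upper = obs_pos.max(axis=0)
    in_range = bool(np.all((query >= lower) & (query <= upper)))

    # An exact match has zero distance to its nearest neighbour.
    exact = bool(dists[idx] == 0.0)

    return idx, float(dists[idx]), in_range, exact


# Illustrative values only; not taken from the real dataset.
obs_pos = [[-200.0, -200.0], [-152.0, -160.0], [0.0, 0.0]]

idx, dist, in_range, exact = nearest_neighbour_report(obs_pos, [-152.5, -160.596])
```

With these illustrative values the query lies inside the bounding box and close to an observed position, yet is not an exact match, which is the situation described above.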

Expected Behaviour
Inference should support:

  • arbitrary FOV positions
  • continuous coordinate space

Your Setup
Tested in a Jupyter notebook running in a conda environment with TensorFlow 2.15.0.

Labels: bug
