diff --git a/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h b/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h new file mode 100644 index 0000000000..9bb73fd1a9 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/allowed_machine_views.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H +#define _FLEXFLOW_COMPILER_ALLOWED_MACHINE_VIEWS_H + +#include "pcg/machine_specification.dtg.h" +#include "pcg/machine_view.dtg.h" +#include "pcg/operator_task_space.dtg.h" + +namespace FlexFlow { + +bool is_valid_machine_view(MachineView const &mv, + OperatorTaskSpace const &task, + MachineSpecification const &ms); + +std::unordered_set + get_allowed_machine_views(MachineSpecification const &machine_spec, + OperatorTaskSpace const &task, + DeviceType device_type); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h b/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h new file mode 100644 index 0000000000..b08ca57851 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h @@ -0,0 +1,32 @@ +#ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H +#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_AND_UPDATE_MACHINE_MAPPING_H + +#include "compiler/search_result.dtg.h" +#include "substitutions/pcg_pattern_match.dtg.h" +#include "substitutions/sub_parallel_computation_graph.dtg.h" +#include "substitutions/substitution.dtg.h" + +namespace FlexFlow { +/** + * @brief Applies \p substitution to \p mapped_pcg at the location specified by + * \p match, returning the resulting SearchResult (mapped pcg) + * + * @param mapped_pcg + * @param substitution + * @param match The location at which to apply substitution. This location in + * sub_pcg should match substitution's PCGPattern. Likely created by running + * FlexFlow::find_pattern_matches(PCGPattern const &, + * SubParallelComputationGraph const &). + * @return SearchResult A mapped pcg similar to mapped_pcg, but with + * the subgraph of the pcg specified by match replaced with the result of the + * output expression of substitution and the machine mapping updated to account + * for the new output + */ +SearchResult apply_substitution_and_update_machine_mapping( + SearchResult const &mapped_pcg, + Substitution const &sub, + PCGPatternMatch const &match); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h new file mode 100644 index 0000000000..16385a74e8 --- /dev/null +++ b/lib/compiler/include/compiler/machine_mapping/machine_mapping_mutation_set.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MCMC_MACHINE_MAPPING_MUTATION_SET_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_MCMC_MACHINE_MAPPING_MUTATION_SET_H + +#include "compiler/machine_mapping/machine_mapping.h" +#include "compiler/search_result.dtg.h" + +namespace FlexFlow { +std::optional + get_random_mapping(ParallelComputationGraph &pcg, + MachineSpecification const &resources, + DeviceType const &device_type); + +std::optional + get_random_mutation(SearchResult mapped_pcg, + MachineSpecification const &resource, + DeviceType const &device_type); +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h new file mode 100644 index 0000000000..a3baa251e3 --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_algorithm.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H +#define _FLEXFLOW_COMPILER_MCMC_GENERIC_MCMC_ALGORITHM_H + +#include "compiler/mcmc/generic_mcmc_config.dtg.h" +#include "utils/containers/transform.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/optional.h" +#include "utils/random_utils.h" + +namespace FlexFlow { + +// SamplingFn : State -> std::optional +// CostFn : State -> float + +template +State run_mcmc(State const &starting_state, + SamplingFn const &sampler, + CostFn const &cost, + GenericMCMCConfig const &search_config) { + State best_state = starting_state; + State current_state = best_state; + for (nonnegative_int i : nonnegative_range(search_config.num_iterations)) { + std::optional maybe_new_state = + transform(sampler(current_state), [&](State const &s) { + float delta = cost(s) - cost(best_state); + if (randf() < exp(-delta / search_config.temperature)) { + if (delta < 0) { + best_state = s; + } + return s; + } + return current_state; + }); + current_state = or_else(maybe_new_state, [&]() { return current_state; }); + } + return best_state; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/mcmc/generic_mcmc_config.struct.toml b/lib/compiler/include/compiler/mcmc/generic_mcmc_config.struct.toml new file mode 100644 index 0000000000..e11c84f0bd --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/generic_mcmc_config.struct.toml @@ -0,0 +1,19 @@ +namespace = "FlexFlow" +name = "GenericMCMCConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/nonnegative_int/nonnegative_int.h" +] + +[[fields]] +name = "temperature" +type = "float" + +[[fields]] +name = "num_iterations" +type = "::FlexFlow::nonnegative_int" \ No newline at end of file diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h new file mode 100644 index 0000000000..c251340626 --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H +#define _FLEXFLOW_COMPILER_MCMC_OVER_MAPPED_PCG_H + +#include "compiler/cost_estimator/runtime_only_cost_estimator.h" +#include "compiler/mcmc/mcmc_over_mapped_pcg_config.dtg.h" +#include "compiler/search_result.dtg.h" +#include "pcg/computation_graph.h" +#include "pcg/machine_specification.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/substitution.h" + +namespace FlexFlow { + +SearchResult + mcmc_over_mapped_pcg(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineSpecification const &resources, + MCMCOverMappedPCGConfig const &search_config); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml new file mode 100644 index 0000000000..76415ee4d9 --- /dev/null +++ b/lib/compiler/include/compiler/mcmc/mcmc_over_mapped_pcg_config.struct.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "MCMCOverMappedPCGConfig" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "pcg/device_type.dtg.h", + "utils/nonnegative_int/nonnegative_int.h" +] + +[[fields]] +name = "temperature" +type = "float" + +[[fields]] +name = "num_iterations" +type = "::FlexFlow::nonnegative_int" + +[[fields]] +name = "substitution_frequency" +type = "float" + +[[fields]] +name = "device_type" +type = "::FlexFlow::DeviceType" \ No newline at end of file diff --git a/lib/compiler/include/compiler/search_result.h b/lib/compiler/include/compiler/search_result.h new file mode 100644 index 0000000000..197b36e9ea --- /dev/null +++ b/lib/compiler/include/compiler/search_result.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H +#define _FLEXFLOW_LIB_COMPILER_INCLUDE_COMPILER_GRAPH_OPTIMIZE_RESULT_H + +#include "compiler/search_result.dtg.h" + +namespace FlexFlow { + +std::string format_as(SearchResult const &); +std::ostream &operator<<(std::ostream &, SearchResult const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/compiler/include/compiler/search_result.struct.toml b/lib/compiler/include/compiler/search_result.struct.toml new file mode 100644 index 0000000000..7e7e59d7c9 --- /dev/null +++ b/lib/compiler/include/compiler/search_result.struct.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SearchResult" +features = [ +] + +includes = [ + "pcg/parallel_computation_graph/parallel_computation_graph.h", + "compiler/machine_mapping/machine_mapping.h", +] + +[[fields]] +name = "pcg" +type = "::FlexFlow::ParallelComputationGraph" + +[[fields]] +name = "machine_mapping" +type = "::FlexFlow::MachineMapping" \ No newline at end of file diff --git a/lib/compiler/src/compiler/allowed_machine_views.cc b/lib/compiler/src/compiler/allowed_machine_views.cc index 370cb5a4ec..64b910bf7d 100644 --- a/lib/compiler/src/compiler/allowed_machine_views.cc +++ b/lib/compiler/src/compiler/allowed_machine_views.cc @@ -57,6 +57,8 @@ static std::unordered_set product(transform(tensor_dims, [](positive_int num_devices) { return nonnegative_int{num_devices.int_from_positive_int() - 1}; })); + min_num_devices_with_full_stride_volume = + std::max(min_num_devices_with_full_stride_volume, 1_n); return ceildiv(total_devices, positive_int{min_num_devices_with_full_stride_volume}); }; diff --git a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc new file mode 100644 index 0000000000..2cb78a38b6 --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc @@ -0,0 +1,78 @@ +#include "compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" +#include "substitutions/apply_substitution/apply_substitution.h" +#include "substitutions/apply_substitution/evaluate_substitution_output.h" +#include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h" +#include "substitutions/open_parallel_tensor_guid_t.h" +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/sub_parallel_computation_graph.h" +#include "substitutions/sub_parallel_computation_graph_data.dtg.h" +#include "substitutions/sub_parallel_computation_graph_edge.h" +#include "utils/containers/filter.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/merge_maps.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/random_utils.h" +#include + +namespace FlexFlow { + +SearchResult apply_substitution_and_update_machine_mapping( + SearchResult const &mapped_pcg, + Substitution const &sub, + PCGPatternMatch const &match) { + SubParallelComputationGraph spcg = sub_pcg_from_full_pcg(mapped_pcg.pcg); + + std::pair + substitution_output_result = + evaluate_substitution_output(spcg, sub, match); + + SubParallelComputationGraph post_substitution_graph = + apply_substitution_from_output_result( + substitution_output_result, spcg, sub, match); + + std::unordered_map post_node_data = + get_sub_pcg_data(post_substitution_graph).node_data; + + std::unordered_set + substitution_output_parallel_layers = + get_parallel_layers(substitution_output_result.first); + + std::unordered_map machine_views = + mapped_pcg.machine_mapping.machine_views; + + std::unordered_set matched_nodes = + unordered_set_of(values(match.node_assignment)); + + std::vector substituted_machine_views = vector_of( + transform(matched_nodes, [&](parallel_layer_guid_t const &node) { + return machine_views.at(node); + })); + + for (parallel_layer_guid_t layer : substitution_output_parallel_layers) { + machine_views.insert_or_assign(layer, + select_random(substituted_machine_views)); + } + + ASSERT(is_subseteq_of(keys(post_node_data), keys(machine_views))); + + std::unordered_map + post_node_machine_views = + filter(machine_views, + [&](std::pair const &p) { + return post_node_data.count(p.first); + }); + + ASSERT(keys(post_node_data) == keys(post_node_machine_views)); + + return SearchResult{ + pcg_from_sub_pcg_by_dropping_inputs(post_substitution_graph), + MachineMapping{post_node_machine_views}}; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc new file mode 100644 index 0000000000..c3c84bb24a --- /dev/null +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_mutation_set.cc @@ -0,0 +1,52 @@ +#include "compiler/machine_mapping/machine_mapping_mutation_set.h" +#include "compiler/allowed_machine_views.h" +#include "pcg/machine_view.h" +#include "pcg/operator_task_space.h" +#include "utils/containers/vector_of.h" +#include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/random_utils.h" + +namespace FlexFlow { + +std::optional + get_random_mapping(ParallelComputationGraph &pcg, + MachineSpecification const &resources, + DeviceType const &device_type) { + std::vector layers = topological_ordering(pcg); + std::unordered_map machine_views; + for (parallel_layer_guid_t layer : layers) { + OperatorTaskSpace task = get_operator_task_space(pcg, layer); + std::unordered_set allowed_machine_views = + get_allowed_machine_views(resources, task, DeviceType::GPU); + if (allowed_machine_views.empty()) { + return std::nullopt; + } + machine_views.insert( + {layer, select_random(vector_of(allowed_machine_views))}); + } + return MachineMapping{machine_views}; +} + +std::optional + get_random_mutation(SearchResult mapped_pcg, + MachineSpecification const &resources, + DeviceType const &device_type) { + ParallelComputationGraph pcg = mapped_pcg.pcg; + std::vector layers = topological_ordering(pcg); + if (layers.size() == 0) { + return std::nullopt; + } + parallel_layer_guid_t random_layer = select_random(layers); + + MachineMapping machine_mapping = mapped_pcg.machine_mapping; + MachineView machine_view = machine_mapping.machine_views.at(random_layer); + OperatorTaskSpace task = get_operator_task_space(pcg, random_layer); + + std::vector allowed_machine_views = + vector_of(get_allowed_machine_views(resources, task, device_type)); + MachineView random_new_machine_view = select_random(allowed_machine_views); + + machine_mapping.machine_views.at(random_layer) = random_new_machine_view; + return machine_mapping; +} +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc new file mode 100644 index 0000000000..2c8fcea86d --- /dev/null +++ b/lib/compiler/src/compiler/mcmc/generic_mcmc_algorithm.cc @@ -0,0 +1,15 @@ +#include "compiler/mcmc/generic_mcmc_algorithm.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using State = value_type<0>; +using SamplingFn = std::function(State)>; +using CostFn = std::function; + +template State run_mcmc(State const &starting_state, + SamplingFn const &sampler, + CostFn const &cost, + GenericMCMCConfig const &search_config); + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc new file mode 100644 index 0000000000..43e80630dd --- /dev/null +++ b/lib/compiler/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -0,0 +1,67 @@ +#include "compiler/mcmc/mcmc_over_mapped_pcg.h" +#include "compiler/machine_mapping/apply_substitution_and_update_machine_mapping.h" +#include "compiler/machine_mapping/machine_mapping_mutation_set.h" +#include "compiler/mcmc/generic_mcmc_algorithm.h" +#include "compiler/search_result.h" +#include "compiler/task_graph_simulator/task_simulator.h" +#include "substitutions/pcg_pattern.h" +#include "substitutions/pcg_pattern_match.h" +#include "substitutions/unity_substitution_set.h" +#include "utils/optional.h" +#include "utils/random_utils.h" +#include + +namespace FlexFlow { + +SearchResult + mcmc_over_mapped_pcg(ParallelComputationGraph &pcg, + RuntimeOnlyCostEstimator const &cost_estimator, + MachineSpecification const &resources, + MCMCOverMappedPCGConfig const &search_config) { + + std::vector substitutions = get_substitution_set(resources); + MachineMapping random_mapping = assert_unwrap( + get_random_mapping(pcg, resources, search_config.device_type)); + SearchResult starting_state = SearchResult{pcg, random_mapping}; + + auto sampler = [&](SearchResult mapped_pcg) -> std::optional { + // applies substitution with substitution_frequency probability + // applies machine mapping mutation with (1 - substitution_frequency) + // probability + ASSERT(search_config.substitution_frequency >= 0 && + search_config.substitution_frequency <= 1); + if (randf() < search_config.substitution_frequency) { + Substitution random_substitution = + assert_unwrap(get_random_substitution(resources)); + std::optional maybe_pattern_match = + get_random_pattern_match(random_substitution.pcg_pattern, + sub_pcg_from_full_pcg(mapped_pcg.pcg)); + return transform(maybe_pattern_match, [&](PCGPatternMatch match) { + return apply_substitution_and_update_machine_mapping( + mapped_pcg, random_substitution, match); + }); + } else { + MachineMapping new_machine_mapping = assert_unwrap(get_random_mutation( + mapped_pcg, resources, search_config.device_type)); + return SearchResult{mapped_pcg.pcg, new_machine_mapping}; + } + }; + + auto cost = [&](SearchResult mapped_pcg) -> float { + return task_simulator_estimate_forward_pass_time(mapped_pcg.pcg, + cost_estimator, + mapped_pcg.machine_mapping, + resources) + .unwrap_milliseconds(); + }; + + GenericMCMCConfig config = + GenericMCMCConfig{/*temperature*/ search_config.temperature, + /*num_iterations*/ search_config.num_iterations}; + + SearchResult result = run_mcmc(starting_state, sampler, cost, config); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/compiler/src/compiler/search_result.cc b/lib/compiler/src/compiler/search_result.cc new file mode 100644 index 0000000000..0afc10723a --- /dev/null +++ b/lib/compiler/src/compiler/search_result.cc @@ -0,0 +1,15 @@ +#include "compiler/search_result.h" + +namespace FlexFlow { + +std::string format_as(SearchResult const &r) { + return fmt::format("", + as_dot(r.pcg), + r.machine_mapping); +} + +std::ostream &operator<<(std::ostream &s, SearchResult const &r) { + return (s << fmt::to_string(r)); +} + +} // namespace FlexFlow diff --git a/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc new file mode 100644 index 0000000000..b21ee4333f --- /dev/null +++ b/lib/compiler/test/src/compiler/mcmc/generic_mcmc_algorithm.cc @@ -0,0 +1,29 @@ +#include "compiler/mcmc/generic_mcmc_algorithm.h" +#include "doctest/doctest.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("generic_mcmc_algorithm") { + float starting_state = 0.1; + auto sampler = [](float x) -> std::optional { + float new_x = x + (randf() - 0.5); + if (new_x < 0) { + return std::nullopt; + } + if (new_x > 1) { + return std::nullopt; + } + return new_x; + }; + auto cost = [](float x) { return (x - 0.5) * (x - 0.5); }; + GenericMCMCConfig config = GenericMCMCConfig{/*temperature=*/1.0, + /*num_iterations=*/100_n}; + float answer = run_mcmc(starting_state, sampler, cost, config); + float error = cost(answer); + CHECK(answer > 0.47); + CHECK(answer < 0.53); + CHECK(error >= 0); + CHECK(error < 0.001); + } +} diff --git a/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc new file mode 100644 index 0000000000..9e2134d08b --- /dev/null +++ b/lib/compiler/test/src/compiler/mcmc/mcmc_over_mapped_pcg.cc @@ -0,0 +1,94 @@ +#include "compiler/mcmc/mcmc_over_mapped_pcg.h" +#include "compiler/task_graph_simulator/task_simulator.h" +#include "doctest/doctest.h" +#include "internal/runtime_only_cost_estimator_for_test.h" +#include "op-attrs/parallel_tensor_dims.h" +#include "op-attrs/parallel_tensor_shape.dtg.h" +#include "op-attrs/replica_type.dtg.h" +#include "op-attrs/shard_parallel_dim.h" +#include "pcg/computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/pcg_from_computation_graph.h" +#include "utils/integer_conversions.h" + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("mcmc_over_mapped_pcg") { + ComputationGraph cg = [&] { + ComputationGraphBuilder b; + TensorShape input_tensor_shape = TensorShape{ + TensorDims{ + FFOrdered{32_p, 64_p}, + }, + DataType::FLOAT, + }; + tensor_guid_t t = b.create_input(input_tensor_shape, CreateGrad::YES); + t = b.dense(t, + /*outDim=*/16_p, + /*activation=*/std::nullopt); + t = b.gelu(t); + t = b.dense(t, + /*outDim=*/12_p, + /*activation=*/std::nullopt, + /*use_bias=*/false, + /*data_type=*/DataType::FLOAT, + /*kernel_initializer=*/std::nullopt, + /*bias_initializer=*/std::nullopt); + t = b.relu(t); + t = b.dense(t, + /*outDim=*/8_p, + /*activation=*/Activation::RELU); + return b.computation_graph; + }(); + + ParallelComputationGraph pcg = pcg_from_computation_graph(cg); + + RuntimeOnlyCostEstimator cost_estimator = + make_fake_constant_runtime_only_cost_estimator( + /*forward_op_cost=*/10_ms, + /*backward_op_cost=*/10_ms, + /*comm_cost=*/1_ms); + + MachineSpecification full_machine_spec = MachineSpecification{ + /*num_nodes=*/2_p, + /*num_cpus_per_node=*/1_p, + /*num_gpus_per_node=*/1_p, + /*inter_node_bandwidth=*/1, + /*intra_node_bandwidth=*/1, + }; + + MCMCOverMappedPCGConfig no_search = + MCMCOverMappedPCGConfig{/*temperature=*/1.0, + /*num_iterations=*/1_n, + /*substitution_frequency=*/0.2, + /*device_type=*/DeviceType::GPU}; + + SearchResult base_result = + mcmc_over_mapped_pcg(pcg, cost_estimator, full_machine_spec, no_search); + float base_runtime = + task_simulator_estimate_forward_pass_time(base_result.pcg, + cost_estimator, + base_result.machine_mapping, + full_machine_spec) + .unwrap_milliseconds(); + + MCMCOverMappedPCGConfig search_config = + MCMCOverMappedPCGConfig{/*temperature=*/1.0, + /*num_iterations=*/100_n, + /*substitution_frequency=*/0.2, + /*device_type=*/DeviceType::GPU}; + + SearchResult result = mcmc_over_mapped_pcg( + pcg, cost_estimator, full_machine_spec, search_config); + float runtime = + task_simulator_estimate_forward_pass_time(result.pcg, + cost_estimator, + result.machine_mapping, + full_machine_spec) + .unwrap_milliseconds(); + + CHECK(runtime < base_runtime); + CHECK(runtime < 100); + } +} diff --git a/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h b/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h index 92f7bb1c03..d46523ecb6 100644 --- a/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h +++ b/lib/substitutions/include/substitutions/apply_substitution/apply_substitution.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_H #define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_APPLY_SUBSTITUTION_APPLY_SUBSTITUTION_H +#include "substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.h" #include "substitutions/pcg_pattern_match.dtg.h" #include "substitutions/sub_parallel_computation_graph.dtg.h" #include "substitutions/substitution.dtg.h" @@ -26,6 +27,13 @@ SubParallelComputationGraph Substitution const &substitution, PCGPatternMatch const &match); +SubParallelComputationGraph apply_substitution_from_output_result( + std::pair + substitution_output_result, + SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match); + } // namespace FlexFlow #endif diff --git a/lib/substitutions/include/substitutions/pcg_pattern.h b/lib/substitutions/include/substitutions/pcg_pattern.h index f0962b15c2..5005a0b51c 100644 --- a/lib/substitutions/include/substitutions/pcg_pattern.h +++ b/lib/substitutions/include/substitutions/pcg_pattern.h @@ -12,6 +12,10 @@ namespace FlexFlow { std::unordered_set get_nodes(PCGPattern const &); +std::optional + get_random_pattern_match(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg); + /** * @brief Find all locations in \p pcg that match \p pattern */ diff --git a/lib/substitutions/include/substitutions/unity_substitution_set.h b/lib/substitutions/include/substitutions/unity_substitution_set.h index 183f76ac8a..574dd9da3d 100644 --- a/lib/substitutions/include/substitutions/unity_substitution_set.h +++ b/lib/substitutions/include/substitutions/unity_substitution_set.h @@ -7,6 +7,9 @@ namespace FlexFlow { +std::optional + get_random_substitution(MachineSpecification const &resources); + std::vector get_substitution_set(MachineSpecification const &resources); diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 61bfe15d7b..611296488e 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -20,8 +20,19 @@ SubParallelComputationGraph apply_substitution(SubParallelComputationGraph const &spcg, Substitution const &sub, PCGPatternMatch const &match) { - auto substitution_output_result = - evaluate_substitution_output(spcg, sub, match); + std::pair + substitution_output_result = + evaluate_substitution_output(spcg, sub, match); + return apply_substitution_from_output_result( + substitution_output_result, spcg, sub, match); +} + +SubParallelComputationGraph apply_substitution_from_output_result( + std::pair + substitution_output_result, + SubParallelComputationGraph const &spcg, + Substitution const &sub, + PCGPatternMatch const &match) { SubParallelComputationGraph substitution_output_graph = substitution_output_result.first; OutputExprToResultSubPCGMapping output_expr_to_result_sub_pcg_mapping = diff --git a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc index 194ae49255..4f61eac113 100644 --- a/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/operator_pattern/satisfies_constraint.cc @@ -16,6 +16,34 @@ bool operator_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val.value() == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + auto get_nonnegative_int_if_possible = + [](OperatorAttributeValue v) -> std::optional { + if (v.has()) { + return v.get(); + } + if (v.has()) { + return v.get().nonnegative_int_from_positive_int(); + } + return std::nullopt; + }; + + if (!expr_val.has_value()) { + throw mk_runtime_error("DIVISIBLE_BY constraint requires " + "nonnegative_int or positive_int values"); + } + + std::optional maybe_expr_val_nn = + get_nonnegative_int_if_possible(expr_val.value()); + std::optional maybe_attr_val_nn = + get_nonnegative_int_if_possible(constraint.attribute_value); + + if (maybe_expr_val_nn.has_value() && maybe_attr_val_nn.has_value()) { + return maybe_expr_val_nn.value() % maybe_attr_val_nn.value() == 0; + } + throw mk_runtime_error("DIVISIBLE_BY constraint requires nonnegative_int " + "or positive_int values"); + } default: throw mk_runtime_error( fmt::format("Unknown constraint type {}", diff --git a/lib/substitutions/src/substitutions/pcg_pattern.cc b/lib/substitutions/src/substitutions/pcg_pattern.cc index a0af875848..1e260f9fe3 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern.cc @@ -11,6 +11,7 @@ #include "utils/graph/node/algorithms.h" #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_graph_inputs.h" +#include "utils/random_utils.h" namespace FlexFlow { @@ -20,6 +21,17 @@ std::unordered_set get_nodes(PCGPattern const &p) { return transform(raw_nodes, [](Node const &n) { return PatternNode{n}; }); } +std::optional + get_random_pattern_match(PCGPattern const &pattern, + SubParallelComputationGraph const &pcg) { + std::vector pattern_matches = + find_pattern_matches(pattern, pcg); + if (pattern_matches.empty()) { + return std::nullopt; + } + return select_random(pattern_matches); +} + static MatchAdditionalCriterion pcg_pattern_criteria(PCGPattern const &pattern, SubParallelComputationGraph const &pcg) { diff --git a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc index 974bfcabc0..cc0af12c91 100644 --- a/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc +++ b/lib/substitutions/src/substitutions/tensor_pattern/satisfies_constraint.cc @@ -12,6 +12,16 @@ bool parallel_tensor_satisfies_constraint( switch (constraint.constraint_type) { case ConstraintType::EQUAL: return expr_val == constraint.attribute_value; + case ConstraintType::DIVISIBLE_BY: { + if (expr_val.has() && + constraint.attribute_value.has()) { + return expr_val.get() % + constraint.attribute_value.get() == + 0; + } + throw mk_runtime_error( + "DIVISIBLE_BY constraint requires nonnegative_int values"); + } default: throw mk_runtime_error( fmt::format("Unknown constraint type {}", diff --git a/lib/substitutions/src/substitutions/unity_substitution_set.cc b/lib/substitutions/src/substitutions/unity_substitution_set.cc index 4b00cdd95f..c8d9266978 100644 --- a/lib/substitutions/src/substitutions/unity_substitution_set.cc +++ b/lib/substitutions/src/substitutions/unity_substitution_set.cc @@ -7,9 +7,19 @@ #include "utils/containers/get_only.h" #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/nonnegative_int/nonnegative_range.h" +#include "utils/random_utils.h" namespace FlexFlow { +std::optional + get_random_substitution(MachineSpecification const &resources) { + std::vector substitutions = get_substitution_set(resources); + if (substitutions.empty()) { + return std::nullopt; + } + return select_random(substitutions); +} + std::vector get_substitution_set(MachineSpecification const &resources) { std::vector substitutions; diff --git a/lib/utils/include/utils/optional.h b/lib/utils/include/utils/optional.h index 377561d70c..4e4bc03cd4 100644 --- a/lib/utils/include/utils/optional.h +++ b/lib/utils/include/utils/optional.h @@ -3,6 +3,7 @@ #include "utils/exception.h" #include "utils/fmt/optional.h" +#include #include namespace FlexFlow { @@ -28,7 +29,7 @@ T const &unwrap(std::optional const &o, F const &f) { template T const &assert_unwrap(std::optional const &o) { - assert(o.has_value()); + ASSERT(o.has_value()); return o.value(); } diff --git a/lib/utils/include/utils/random_utils.h b/lib/utils/include/utils/random_utils.h index 99da9646a1..014c38fc51 100644 --- a/lib/utils/include/utils/random_utils.h +++ b/lib/utils/include/utils/random_utils.h @@ -5,7 +5,7 @@ #include #include -float randf() { +inline float randf() { return static_cast(std::rand()) / static_cast(RAND_MAX); }