Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ OperatorAttributeConstraint op_type_equals_constraint(OperatorType);

OperatorAttributeConstraint op_attr_key_equals(OperatorAttributeKey,
OperatorAttributeValue const &);
OperatorAttributeConstraint
op_attr_key_divisible_by(OperatorAttributeKey, nonnegative_int denominator);
OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey,
positive_int denominator);
OperatorAttributeConstraint
make_equals_constraint(OperatorAttributeExpr const &,
OperatorAttributeValue const &);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
#define _FLEXFLOW_LIB_SUBSTITUTIONS_INCLUDE_SUBSTITUTIONS_TENSOR_PATTERN_TENSOR_ATTRIBUTE_PATTERN_H

#include "substitutions/tensor_pattern/tensor_attribute_pattern.dtg.h"
#include "utils/nonnegative_int/nonnegative_int.h"
#include "utils/positive_int/positive_int.h"

namespace FlexFlow {

TensorAttributePattern tensor_attribute_pattern_match_all();
TensorAttributePattern
tensor_attr_pattern_require_num_dims(nonnegative_int num_dims);
tensor_attr_pattern_require_num_dims(positive_int num_dims);

} // namespace FlexFlow

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@ includes = [
"utils/hash/vector.h",
"utils/fmt/vector.h",
"utils/nonnegative_int/nonnegative_int.h",
"utils/positive_int/positive_int.h",
]

[[values]]
type = "::FlexFlow::nonnegative_int"

[[values]]
type = "::FlexFlow::positive_int"

[[values]]
type = "std::vector<::FlexFlow::nonnegative_int>"

[[values]]
type = "std::vector<::FlexFlow::positive_int>"

37 changes: 13 additions & 24 deletions lib/substitutions/include/substitutions/unity_substitution_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,25 @@ namespace FlexFlow {
std::vector<Substitution>
get_substitution_set(MachineSpecification const &resources);

Substitution create_combine_inception(nonnegative_int num_convs,
nonnegative_int num_dims,
nonnegative_int degree);
Substitution create_combine_concat(nonnegative_int num_inputs,
nonnegative_int num_dims,
nonnegative_int degree);
Substitution create_replicate_linear_combine(nonnegative_int num_dims,
nonnegative_int degree,
Substitution create_replicate_linear_combine(positive_int num_dims,
positive_int degree,
bool use_bias);
Substitution create_partition_linear_combine(nonnegative_int num_dims,
nonnegative_int degree,
Activation activation,
Substitution create_partition_linear_combine(positive_int num_dims,
positive_int degree,
bool use_bias);
Substitution create_partition_conv2d_combine(nonnegative_int num_dims,
nonnegative_int degree);
Substitution create_partition_attention_combine(nonnegative_int num_heads,
nonnegative_int degree);
Substitution create_replicate_attention_reduce(nonnegative_int num_heads,
nonnegative_int degree);
Substitution create_partition_conv2d_combine(positive_int num_dims,
positive_int degree);
Substitution create_partition_attention_combine(positive_int num_heads,
positive_int degree);
Substitution create_replicate_attention_reduce(positive_int num_heads,
positive_int degree);
Substitution create_partition_add_combine(ff_dim_t parallel_dim,
nonnegative_int degree);
positive_int degree);
Substitution create_partition_relu_combine(ff_dim_t parallel_dim,
nonnegative_int degree);
Substitution create_partition_concat_combine(nonnegative_int num_inputs,
ff_dim_t concat_dim,
ff_dim_t parallel_dim,
nonnegative_int degree);
positive_int degree);
Substitution create_partition_softmax_combine(ff_dim_t softmax_dim,
ff_dim_t partition_dim,
nonnegative_int degree);
positive_int degree);
Substitution create_fuse_linear_activation(Activation activation);

} // namespace FlexFlow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ std::optional<OperatorAttributeValue> get_attribute(ConcatAttrs const &p,
std::optional<OperatorAttributeValue> get_attribute(Conv2DAttrs const &p,
OperatorAttributeKey key) {
switch (key) {
case OperatorAttributeKey::OUT_CHANNELS:
return OperatorAttributeValue{p.out_channels};
case OperatorAttributeKey::OP_TYPE:
return OperatorAttributeValue{get_op_type(p)};
case OperatorAttributeKey::KERNEL_H:
Expand Down Expand Up @@ -113,6 +115,12 @@ std::optional<OperatorAttributeValue> get_attribute(ElementBinaryAttrs const &p,
switch (key) {
case OperatorAttributeKey::OP_TYPE:
return OperatorAttributeValue{get_op_type(p)};
case OperatorAttributeKey::DATA_TYPE:
return OperatorAttributeValue{p.compute_type};
case OperatorAttributeKey::SHOULD_BROADCAST_LHS:
return OperatorAttributeValue{p.should_broadcast_lhs};
case OperatorAttributeKey::SHOULD_BROADCAST_RHS:
return OperatorAttributeValue{p.should_broadcast_rhs};
default:
return std::nullopt;
}
Expand All @@ -123,6 +131,8 @@ std::optional<OperatorAttributeValue> get_attribute(ElementUnaryAttrs const &p,
switch (key) {
case OperatorAttributeKey::OP_TYPE:
return OperatorAttributeValue{get_op_type(p)};
case OperatorAttributeKey::SCALAR:
return OperatorAttributeValue{p.scalar};
default:
return std::nullopt;
}
Expand Down Expand Up @@ -227,10 +237,20 @@ std::optional<OperatorAttributeValue>
switch (key) {
case OperatorAttributeKey::OP_TYPE:
return OperatorAttributeValue{get_op_type(p)};
case OperatorAttributeKey::EMBED_DIM:
return OperatorAttributeValue{p.embed_dim};
case OperatorAttributeKey::KDIM:
return OperatorAttributeValue{p.kdim};
case OperatorAttributeKey::VDIM:
return OperatorAttributeValue{p.vdim};
case OperatorAttributeKey::NUM_HEADS:
return OperatorAttributeValue{p.num_heads};
case OperatorAttributeKey::USE_BIAS:
case OperatorAttributeKey::BIAS:
return OperatorAttributeValue{p.bias};
case OperatorAttributeKey::ADD_BIAS_KV:
return OperatorAttributeValue{p.add_bias_kv};
case OperatorAttributeKey::ADD_ZERO_ATTN:
return OperatorAttributeValue{p.add_bias_kv};
case OperatorAttributeKey::DROPOUT:
return OperatorAttributeValue{p.dropout};
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ OperatorAttributeConstraint
};
}

OperatorAttributeConstraint
op_attr_key_divisible_by(OperatorAttributeKey key,
nonnegative_int denominator) {
OperatorAttributeConstraint op_attr_key_divisible_by(OperatorAttributeKey key,
positive_int denominator) {
return OperatorAttributeConstraint{
ConstraintType::DIVISIBLE_BY,
OperatorAttributeExpr{key},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
case OperatorType::NOOP:
case OperatorType::INPUT:
case OperatorType::WEIGHT:
case OperatorType::CONV2D:
case OperatorType::DROPOUT:
case OperatorType::LINEAR:
return PCGOperatorAttrs{LinearAttrs{
Expand All @@ -75,19 +74,72 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
acc.get<std::optional<RegularizerAttrs>>(
OperatorAttributeKey::REGULARIZER),
}};
case OperatorType::CONV2D:
return PCGOperatorAttrs{Conv2DAttrs{
/*out_channels=*/acc.get<positive_int>(
OperatorAttributeKey::OUT_CHANNELS),
/*kernel_h=*/acc.get<positive_int>(OperatorAttributeKey::KERNEL_H),
/*kernel_w=*/acc.get<positive_int>(OperatorAttributeKey::KERNEL_W),
/*stride_h=*/acc.get<positive_int>(OperatorAttributeKey::STRIDE_H),
/*stride_w=*/acc.get<positive_int>(OperatorAttributeKey::STRIDE_W),
/*padding_h=*/
acc.get<nonnegative_int>(OperatorAttributeKey::PADDING_H),
/*padding_w=*/
acc.get<nonnegative_int>(OperatorAttributeKey::PADDING_W),
/*groups=*/acc.get<positive_int>(OperatorAttributeKey::GROUPS),
/*activation=*/
acc.get<std::optional<Activation>>(OperatorAttributeKey::ACTIVATION),
/*use_bias=*/acc.get<bool>(OperatorAttributeKey::USE_BIAS),
}};
case OperatorType::RELU:
return PCGOperatorAttrs{ElementUnaryAttrs{
acc.get<OperatorType>(OperatorAttributeKey::OP_TYPE),
acc.get<std::optional<float>>(OperatorAttributeKey::SCALAR),
}};
case OperatorType::SOFTMAX:
return PCGOperatorAttrs{SoftmaxAttrs{
acc.get<ff_dim_t>(OperatorAttributeKey::AXIS),
}};
case OperatorType::EW_ADD:
return PCGOperatorAttrs{ElementBinaryAttrs{
acc.get<OperatorType>(OperatorAttributeKey::OP_TYPE),
acc.get<DataType>(OperatorAttributeKey::DATA_TYPE),
acc.get<bool>(OperatorAttributeKey::SHOULD_BROADCAST_LHS),
acc.get<bool>(OperatorAttributeKey::SHOULD_BROADCAST_LHS),
}};
case OperatorType::REPLICATE:
return PCGOperatorAttrs{ReplicateAttrs{
/*replicate_degree=*/acc.get<positive_int>(
OperatorAttributeKey::PARALLEL_DEGREE),
}};
case OperatorType::REPARTITION:
return PCGOperatorAttrs{RepartitionAttrs{
/*repartition_dim=*/acc.get<ff_dim_t>(
OperatorAttributeKey::PARALLEL_DIM),
/*repartition_Degree=*/
acc.get<positive_int>(OperatorAttributeKey::PARALLEL_DEGREE),
}};
case OperatorType::COMBINE:
return PCGOperatorAttrs{CombineAttrs{
/*combine_dim=*/acc.get<ff_dim_t>(OperatorAttributeKey::PARALLEL_DIM),
/*combine_degree=*/
acc.get<positive_int>(OperatorAttributeKey::PARALLEL_DEGREE),
}};
case OperatorType::REDUCTION:
return PCGOperatorAttrs{ReductionAttrs{
acc.get<positive_int>(OperatorAttributeKey::PARALLEL_DEGREE),
}};
case OperatorType::BATCHMATMUL:
case OperatorType::SCALAR_MULTIPLY:
case OperatorType::SCALAR_ADD:
case OperatorType::SCALAR_FLOOR_DIV:
case OperatorType::SCALAR_TRUE_DIV:
case OperatorType::SCALAR_SUB:
case OperatorType::RELU:
case OperatorType::IDENTITY:
case OperatorType::SIGMOID:
case OperatorType::TANH:
case OperatorType::ELU:
case OperatorType::FLAT:
case OperatorType::SOFTMAX:
case OperatorType::BATCHNORM:
case OperatorType::CONCAT:
case OperatorType::SPLIT:
Expand All @@ -96,7 +148,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
case OperatorType::RESHAPE:
case OperatorType::REVERSE:
case OperatorType::TRANSPOSE:
case OperatorType::EW_ADD:
case OperatorType::EW_MUL:
case OperatorType::MATMUL:
case OperatorType::MUL:
Expand Down Expand Up @@ -143,10 +194,6 @@ PCGOperatorAttrs materialize_operator_from_attrs_map(
case OperatorType::LAYERNORM:
case OperatorType::GATHER:
case OperatorType::BROADCAST:
case OperatorType::REPARTITION:
case OperatorType::COMBINE:
case OperatorType::REPLICATE:
case OperatorType::REDUCTION:
case OperatorType::BATCH:
case OperatorType::PIPELINE:
case OperatorType::FUSED_PARALLEL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ TensorAttributePattern tensor_attribute_pattern_match_all() {
}

TensorAttributePattern
tensor_attr_pattern_require_num_dims(nonnegative_int num_dims) {
tensor_attr_pattern_require_num_dims(positive_int num_dims) {
return TensorAttributePattern{{
TensorAttributeConstraint{
ConstraintType::EQUAL,
Expand Down
Loading
Loading