diff --git a/lib/Analysis/MulDepthAnalysis/BUILD b/lib/Analysis/MulDepthAnalysis/BUILD index d81225e5f0..0f712d642c 100644 --- a/lib/Analysis/MulDepthAnalysis/BUILD +++ b/lib/Analysis/MulDepthAnalysis/BUILD @@ -12,6 +12,7 @@ cc_library( deps = [ "@heir//lib/Analysis:Utils", "@heir//lib/Analysis/SecretnessAnalysis", + "@heir//lib/Dialect:HEIRInterfaces", "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", "@llvm-project//llvm:Support", diff --git a/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.cpp b/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.cpp index 8998767439..d45fa37b0b 100644 --- a/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.cpp +++ b/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.cpp @@ -5,12 +5,10 @@ #include #include "lib/Analysis/Utils.h" -#include "lib/Dialect/Mgmt/IR/MgmtOps.h" +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Secret/IR/SecretOps.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project @@ -22,11 +20,6 @@ namespace heir { FailureOr deriveResultMulDepth( Operation* op, ArrayRef operands) { - auto isMul = false; - if (isa(op)) { - isMul = true; - } - int64_t operandsMulDepth = 0; for (auto* operand : operands) { if (!operand || !operand->getValue().isInitialized()) { @@ -36,7 +29,9 @@ FailureOr deriveResultMulDepth( std::max(operandsMulDepth, operand->getValue().getMulDepth()); } - return operandsMulDepth + (isMul ? 1 : 0); + int64_t increase = 0; + if (dyn_cast(op)) increase = 1; + return operandsMulDepth + increase; } LogicalResult MulDepthAnalysis::visitOperation( diff --git a/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h b/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h index 16e45a8a38..b2b9293110 100644 --- a/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h +++ b/lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h @@ -7,12 +7,14 @@ #include #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Secret/IR/SecretTypes.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project @@ -86,7 +88,8 @@ class MulDepthAnalysis friend class SecretnessAnalysisDependent; void setToEntryState(MulDepthLattice* lattice) override { - if (isa(lattice->getAnchor().getType())) { + if (isa( + getElementTypeOrSelf(lattice->getAnchor().getType()))) { propagateIfChanged(lattice, lattice->join(MulDepthState(0))); return; } diff --git a/lib/Analysis/SecretnessAnalysis/BUILD b/lib/Analysis/SecretnessAnalysis/BUILD index d5b3fae846..81742b1669 100644 --- a/lib/Analysis/SecretnessAnalysis/BUILD +++ b/lib/Analysis/SecretnessAnalysis/BUILD @@ -12,6 +12,7 @@ cc_library( hdrs = ["SecretnessAnalysis.h"], deps = [ "@heir//lib/Analysis:Utils", + "@heir//lib/Dialect:HEIRInterfaces", "@heir//lib/Dialect/Secret/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", diff --git a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp index e712738f5f..cd3d9ad9a4 100644 --- a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp +++ b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.cpp @@ -5,6 +5,7 @@ #include #include "lib/Analysis/Utils.h" +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Secret/IR/SecretDialect.h" #include "lib/Dialect/Secret/IR/SecretOps.h" #include "lib/Dialect/Secret/IR/SecretTypes.h" @@ -15,6 +16,8 @@ #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project @@ -26,7 +29,7 @@ namespace heir { void SecretnessAnalysis::setToEntryState(SecretnessLattice* lattice) { auto value = lattice->getAnchor(); - bool secretness = isa(value.getType()); + bool secretness = isa(value.getType()); auto blockArg = dyn_cast(value); Operation* operation = nullptr; @@ -48,9 +51,10 @@ void SecretnessAnalysis::setToEntryState(SecretnessLattice* lattice) { // check if the operand is either of secret type or annotated with // {secret.secret} if (auto funcOp = dyn_cast(*operation)) { - // Check if it has secret type - secretness = isa( - funcOp.getArgumentTypes()[blockArg.getArgNumber()]); + Type argType = funcOp.getArgumentTypes()[blockArg.getArgNumber()]; + // Check if it has secret-like type (secret, ciphertext, or shaped type + // thereof) + secretness = isa(getElementTypeOrSelf(argType)); // check if it is annotated as {secret.secret} UnitAttr attr = funcOp.getArgAttrOfType( diff --git a/lib/Dialect/BUILD b/lib/Dialect/BUILD index 09e78bb088..ca109885f0 100644 --- a/lib/Dialect/BUILD +++ b/lib/Dialect/BUILD @@ -14,11 +14,13 @@ cc_library( srcs = ["HEIRInterfaces.cpp"], hdrs = ["HEIRInterfaces.h"], deps = [ - ":interfaces_inc_gen", + ":op_interfaces_inc_gen", + ":type_interfaces_inc_gen", "@heir//lib/Dialect/Secret/IR:SecretAttributes", "@heir//lib/Transforms/LayoutOptimization:Hoisting", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -70,10 +72,23 @@ td_library( ) gentbl_cc_library( - name = "interfaces_inc_gen", + name = "op_interfaces_inc_gen", + tbl_outs = { + "HEIROpInterfaces.h.inc": ["-gen-op-interface-decls"], + "HEIROpInterfaces.cpp.inc": ["-gen-op-interface-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HEIRInterfaces.td", + deps = [ + ":td_files", + ], +) + +gentbl_cc_library( + name = "type_interfaces_inc_gen", tbl_outs = { - "HEIRInterfaces.h.inc": ["-gen-op-interface-decls"], - "HEIRInterfaces.cpp.inc": ["-gen-op-interface-defs"], + "HEIRTypeInterfaces.h.inc": ["-gen-type-interface-decls"], + "HEIRTypeInterfaces.cpp.inc": ["-gen-type-interface-defs"], }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "HEIRInterfaces.td", diff --git a/lib/Dialect/HEIRInterfaces.cpp b/lib/Dialect/HEIRInterfaces.cpp index a15c7c3061..2f866f45cf 100644 --- a/lib/Dialect/HEIRInterfaces.cpp +++ b/lib/Dialect/HEIRInterfaces.cpp @@ -6,18 +6,20 @@ #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { namespace heir { -#include "lib/Dialect/HEIRInterfaces.cpp.inc" +#include "lib/Dialect/HEIROpInterfaces.cpp.inc" +#include "lib/Dialect/HEIRTypeInterfaces.cpp.inc" void registerOperandAndResultAttrInterface(DialectRegistry& registry) { registry.addExtension(+[](MLIRContext* ctx, affine::AffineDialect* dialect) { @@ -25,6 +27,13 @@ void registerOperandAndResultAttrInterface(DialectRegistry& registry) { }); } +void registerIncreasesMulDepthOpInterface(DialectRegistry& registry) { + registry.addExtension(+[](MLIRContext* ctx, arith::ArithDialect* dialect) { + arith::MulIOp::attachInterface(*ctx); + arith::MulFOp::attachInterface(*ctx); + }); +} + LogicalResult verifyElementwiseByOperandImpl( ElementwiseByOperandOpInterface opInterface) { Operation* op = opInterface.getOperation(); diff --git a/lib/Dialect/HEIRInterfaces.h b/lib/Dialect/HEIRInterfaces.h index dcbb551f6d..ff8855f722 100644 --- a/lib/Dialect/HEIRInterfaces.h +++ b/lib/Dialect/HEIRInterfaces.h @@ -19,6 +19,7 @@ namespace heir { class ElementwiseByOperandOpInterface; void registerOperandAndResultAttrInterface(DialectRegistry& registry); +void registerIncreasesMulDepthOpInterface(DialectRegistry& registry); LogicalResult verifyElementwiseByOperandImpl( ElementwiseByOperandOpInterface op); @@ -27,7 +28,8 @@ LogicalResult verifyElementwiseByOperandImpl( } // namespace mlir // IWYU pragma: begin_keep -#include "lib/Dialect/HEIRInterfaces.h.inc" +#include "lib/Dialect/HEIROpInterfaces.h.inc" +#include "lib/Dialect/HEIRTypeInterfaces.h.inc" #include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project // IWYU pragma: end_keep diff --git a/lib/Dialect/HEIRInterfaces.td b/lib/Dialect/HEIRInterfaces.td index e55c03e897..3984847fb8 100644 --- a/lib/Dialect/HEIRInterfaces.td +++ b/lib/Dialect/HEIRInterfaces.td @@ -1,5 +1,31 @@ +#ifndef LIB_DIALECT_HEIR_INTERFACES_TD_ +#define LIB_DIALECT_HEIR_INTERFACES_TD_ + include "mlir/IR/OpBase.td" +def SecretTypeInterface : TypeInterface<"SecretTypeInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + This interface marks that a type represents a secret value. This + is an interface (beyond just the `secret.secret` type) so that + concrete ciphertext types can also be marked as secret and reuse + analyses that depend on `SecretnessAnalysis`. + }]; +} + +def IncreasesMulDepthOpInterface : OpInterface<"IncreasesMulDepthOpInterface"> { + let cppNamespace = "::mlir::heir"; + let description = [{ + A trait that signals whether an operation is a mul-like operation, used for + generalizing multiplicative depth analysis. + }]; + + // No methods; We declare an interface instead of a Trait because we need to + // attach the interface to upstream MLIR ops, and there is no way to do that + // with Traits because Traits are intended to be statically checkable. Cf. + // https://discourse.llvm.org/t/how-to-add-addtional-traits-to-existing-ops/62039 +} + def LUTOpInterface : OpInterface<"LUTOpInterface"> { let cppNamespace = "::mlir::heir"; let description = [{ @@ -361,3 +387,5 @@ def OperandAndResultAttrInterface : OpInterface<"OperandAndResultAttrInterface"> } }]; } + +#endif // LIB_DIALECT_HEIR_IR_HEIRINTERFACES_TD_ diff --git a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/BUILD b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/BUILD index 8e53be6963..1702e0aadb 100644 --- a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/BUILD +++ b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/BUILD @@ -17,11 +17,13 @@ cc_library( "@heir//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Dialect/CKKS/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/ModArith/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Parameters/CKKS:Params", "@heir//lib/Utils", "@heir//lib/Utils:ConversionUtils", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp index 6a958412aa..a9b45654af 100644 --- a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp @@ -18,8 +18,10 @@ #include "lib/Parameters/CKKS/Params.h" #include "lib/Utils/ConversionUtils.h" #include "lib/Utils/Utils.h" -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project @@ -37,19 +39,33 @@ #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project // IWYU pragma: end_keep +#define DEBUG_TYPE "lwe-to-openfhe" + namespace mlir::heir::lwe { #define GEN_PASS_DEF_LWETOOPENFHE #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h.inc" +Type convertLWEType(Type type) { + return llvm::TypeSwitch(type) + .Case( + [&](auto ty) { return openfhe::PublicKeyType::get(ty.getContext()); }) + .Case([&](auto ty) { + return openfhe::PrivateKeyType::get(ty.getContext()); + }) + .Case( + [&](auto ty) { return openfhe::PlaintextType::get(ty.getContext()); }) + .Case([&](auto ty) { + return openfhe::CiphertextType::get(ty.getContext()); + }) + .Case([&](auto ty) { + return ty.clone(convertLWEType(ty.getElementType())); + }) + .Default([&](Type ty) { return ty; }); +} + ToOpenfheTypeConverter::ToOpenfheTypeConverter(MLIRContext* ctx) { - addConversion([](Type type) { return type; }); - addConversion([ctx](lwe::LWEPublicKeyType type) -> Type { - return openfhe::PublicKeyType::get(ctx); - }); - addConversion([ctx](lwe::LWESecretKeyType type) -> Type { - return openfhe::PrivateKeyType::get(ctx); - }); + addConversion([&](Type type) { return convertLWEType(type); }); } FailureOr getContextualCryptoContext(Operation* op) { @@ -146,9 +162,10 @@ struct ConvertEncryptOp : public OpConversionPattern { Value cryptoContext = result.value(); rewriter.replaceOp( - op, openfhe::EncryptOp::create(rewriter, op.getLoc(), - op.getOutput().getType(), cryptoContext, - adaptor.getInput(), adaptor.getKey())); + op, openfhe::EncryptOp::create( + rewriter, op.getLoc(), + openfhe::CiphertextType::get(op.getContext()), cryptoContext, + adaptor.getInput(), adaptor.getKey())); return success(); } }; @@ -167,9 +184,10 @@ struct ConvertDecryptOp : public OpConversionPattern { Value cryptoContext = result.value(); rewriter.replaceOp( - op, openfhe::DecryptOp::create( - rewriter, op.getLoc(), op.getOutput().getType(), cryptoContext, - adaptor.getInput(), adaptor.getSecretKey())); + op, + openfhe::DecryptOp::create( + rewriter, op.getLoc(), openfhe::PlaintextType::get(op.getContext()), + cryptoContext, adaptor.getInput(), adaptor.getSecretKey())); return success(); } }; @@ -232,7 +250,7 @@ struct ConvertEncodeOp : public OpConversionPattern { } } - lwe::LWEPlaintextType plaintextType = op.getResult().getType(); + auto plaintextType = openfhe::PlaintextType::get(op.getContext()); return llvm::TypeSwitch(op.getEncoding()) .Case([&](auto encoding) { rewriter.replaceOpWithNewOp( @@ -264,6 +282,45 @@ struct ConvertEncodeOp : public OpConversionPattern { } }; +struct ConvertDecodeOp : public OpConversionPattern { + explicit ConvertDecodeOp(const mlir::TypeConverter& typeConverter, + mlir::MLIRContext* context) + : mlir::OpConversionPattern(typeConverter, context) {} + + LogicalResult matchAndRewrite( + lwe::RLWEDecodeOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + return llvm::TypeSwitch(op.getEncoding()) + .Case([&](auto encoding) { + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), adaptor.getInput()); + return success(); + }) + .Case([&](auto encoding) { + // TODO (#1192): support coefficient packing in `--lwe-to-openfhe` + op.emitError() << "HEIR does not yet support coefficient encoding " + " when targeting OpenFHE"; + return rewriter.notifyMatchFailure( + op, + "HEIR does not yet support coefficient encoding when targeting " + "OpenFHE"); + }) + .Case([&](auto encoding) { + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), adaptor.getInput()); + return success(); + }) + .Default([&](Attribute) -> LogicalResult { + // encoding isn't support explicitly: + op.emitError( + "Unexpected encoding while targeting OpenFHE. " + "If you expect this type of encoding to be supported " + "for the OpenFHE backend, please file a bug report."); + return rewriter.notifyMatchFailure(op, "Unknown encoding"); + }); + } +}; + struct ConvertBootstrapOp : public OpConversionPattern { ConvertBootstrapOp(mlir::MLIRContext* context) : OpConversionPattern(context) {} @@ -298,6 +355,18 @@ struct ConvertBootstrapOp : public OpConversionPattern { return success(); } }; + +struct EraseLWEReinterpretApplicationData + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + lwe::ReinterpretApplicationDataOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOp(op, adaptor.getOperands()[0]); + return success(); + } +}; } // namespace struct LWEToOpenfhe : public impl::LWEToOpenfheBase { @@ -345,8 +414,6 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase { target.addIllegalDialect(); target.addIllegalDialect(); target.addIllegalDialect(); - // We can keep the following ops, which the emitter can handle directly - target.addLegalOp(); RewritePatternSet patterns(context); addStructuralConversionPatterns(typeConverter, patterns, target); @@ -391,9 +458,9 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase { // Update Func CallOp Signature ConvertFuncCallOp, - // Handle LWE encode and en/decrypt - // Note: `lwe.decode` is handled directly by the OpenFHE emitter - ConvertEncodeOp, ConvertEncryptOp, ConvertDecryptOp, + // Encoding and encryption + ConvertEncodeOp, ConvertDecodeOp, ConvertEncryptOp, ConvertDecryptOp, + EraseLWEReinterpretApplicationData, // Scheme-agnostic RLWE Arithmetic Ops: ConvertLWEBinOp, @@ -425,6 +492,13 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase { ConvertBootstrapOp>(typeConverter, context); ConversionConfig config; + // We need allowPatternRollback here because failure to legalize an op + // (like a relinearize op with an invalid basis, as tested in invalid.mlir) + // is then processed by ConvertAny<>, and when that fails to legalize, the + // hard error makes it so --verify-diagnostics cannot be applied, and + // in turn lit tests break. Seems annoying to fix the lit tests (pipe stderr + // to stdout and then FileCheck on the combined stream? instead of + // --verify-diagnostics) config.allowPatternRollback = false; if (failed(applyPartialConversion(module, target, std::move(patterns), config))) { diff --git a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h index 6273131ab0..4ee2901457 100644 --- a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h +++ b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h @@ -6,6 +6,7 @@ // IWYU pragma: end_keep #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" +#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project @@ -55,9 +56,9 @@ struct ConvertLWEBinOp : public OpConversionPattern { if (failed(result)) return result; Value cryptoContext = result.value(); - rewriter.replaceOpWithNewOp(op, op.getOutput().getType(), - cryptoContext, adaptor.getLhs(), - adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + op, openfhe::CiphertextType::get(op.getContext()), cryptoContext, + adaptor.getLhs(), adaptor.getRhs()); return success(); } }; @@ -73,9 +74,9 @@ struct ConvertCiphertextPlaintextOp : public OpConversionPattern { if (failed(result)) return result; Value cryptoContext = result.value(); - rewriter.replaceOpWithNewOp(op, op.getOutput().getType(), - cryptoContext, adaptor.getLhs(), - adaptor.getRhs()); + rewriter.replaceOpWithNewOp( + op, openfhe::CiphertextType::get(op.getContext()), cryptoContext, + adaptor.getLhs(), adaptor.getRhs()); return success(); } }; @@ -118,6 +119,7 @@ struct ConvertRelinOp : public OpConversionPattern { ConversionPatternRewriter& rewriter) const override { FailureOr result = getContextualCryptoContext(op.getOperation()); if (failed(result)) return result; + Value cryptoContext = result.value(); auto toBasis = adaptor.getToBasis(); @@ -128,10 +130,9 @@ struct ConvertRelinOp : public OpConversionPattern { op.emitError() << "toBasis must be [0, 1], got [" << toBasis << "]"; return failure(); } - - Value cryptoContext = result.value(); - rewriter.replaceOpWithNewOp(op, op.getOutput().getType(), - cryptoContext, adaptor.getInput()); + rewriter.replaceOpWithNewOp( + op, openfhe::CiphertextType::get(op.getContext()), cryptoContext, + adaptor.getInput()); return success(); } }; @@ -153,7 +154,8 @@ struct ConvertModulusSwitchOp : public OpConversionPattern { Value cryptoContext = result.value(); rewriter.replaceOp(op, openfhe::ModReduceOp::create( - rewriter, op.getLoc(), op.getOutput().getType(), + rewriter, op.getLoc(), + openfhe::CiphertextType::get(op.getContext()), cryptoContext, adaptor.getInput())); return success(); } @@ -175,7 +177,8 @@ struct ConvertLevelReduceOp : public OpConversionPattern { Value cryptoContext = result.value(); rewriter.replaceOp( op, openfhe::LevelReduceOp::create( - rewriter, op.getLoc(), op.getOutput().getType(), cryptoContext, + rewriter, op.getLoc(), + openfhe::CiphertextType::get(op.getContext()), cryptoContext, adaptor.getInput(), op.getLevelToDrop())); return success(); } diff --git a/lib/Dialect/LWE/IR/BUILD b/lib/Dialect/LWE/IR/BUILD index 56025495d5..b72d2a9dd3 100644 --- a/lib/Dialect/LWE/IR/BUILD +++ b/lib/Dialect/LWE/IR/BUILD @@ -29,6 +29,7 @@ cc_library( ":enums_inc_gen", ":ops_inc_gen", ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", "@heir//lib/Dialect/ModArith/IR:Dialect", "@heir//lib/Dialect/Polynomial/IR:Dialect", "@heir//lib/Dialect/RNS/IR:Dialect", @@ -69,6 +70,7 @@ td_library( # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], deps = [ + "@heir//lib/Dialect:td_files", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", diff --git a/lib/Dialect/LWE/IR/LWEOps.h b/lib/Dialect/LWE/IR/LWEOps.h index cc77c3ea3f..4e1b3a2194 100644 --- a/lib/Dialect/LWE/IR/LWEOps.h +++ b/lib/Dialect/LWE/IR/LWEOps.h @@ -16,6 +16,7 @@ #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project // IWYU pragma: begin_keep +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/LWE/IR/LWEDialect.h" #include "lib/Dialect/LWE/IR/LWETraits.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/lib/Dialect/LWE/IR/LWEOps.td b/lib/Dialect/LWE/IR/LWEOps.td index 8ca6d28d33..01d6fb1070 100644 --- a/lib/Dialect/LWE/IR/LWEOps.td +++ b/lib/Dialect/LWE/IR/LWEOps.td @@ -1,6 +1,7 @@ #ifndef LIB_DIALECT_LWE_IR_LWEOPS_TD_ #define LIB_DIALECT_LWE_IR_LWEOPS_TD_ +include "lib/Dialect/HEIRInterfaces.td" include "lib/Dialect/LWE/IR/LWEDialect.td" include "lib/Dialect/LWE/IR/LWETraits.td" include "lib/Dialect/LWE/IR/LWETypes.td" @@ -160,7 +161,7 @@ def LWE_RSubPlainOp : LWE_CiphertextPlaintextOp<"rsub_plain", [SameOperandsAndRe let summary = "Subtraction between RLWE ciphertext-plaintext"; } -def LWE_RMulOp : LWE_BinOp<"rmul", [SameOperandsAndResultRings, InferTypeOpAdaptor, Commutative]> { +def LWE_RMulOp : LWE_BinOp<"rmul", [SameOperandsAndResultRings, InferTypeOpAdaptor, Commutative, IncreasesMulDepthOpInterface]> { let summary = "Multiplies two RLWE ciphertexts"; let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) @@ -168,7 +169,7 @@ def LWE_RMulOp : LWE_BinOp<"rmul", [SameOperandsAndResultRings, InferTypeOpAdapt let hasVerifier = 1; } -def LWE_RMulPlainOp : LWE_CiphertextPlaintextOp<"rmul_plain", [Commutative]> { +def LWE_RMulPlainOp : LWE_CiphertextPlaintextOp<"rmul_plain", [Commutative, IncreasesMulDepthOpInterface]> { let summary = "Multiplication between RLWE ciphertext-plaintext"; let hasVerifier = 1; let hasCanonicalizer = 1; diff --git a/lib/Dialect/Lattigo/IR/BUILD b/lib/Dialect/Lattigo/IR/BUILD index 6ca4637434..696054457e 100644 --- a/lib/Dialect/Lattigo/IR/BUILD +++ b/lib/Dialect/Lattigo/IR/BUILD @@ -28,6 +28,7 @@ cc_library( ":LattigoAttributes", ":LattigoOps", ":LattigoTypes", + "@heir//lib/Dialect:HEIRInterfaces", "@heir//lib/Utils/Tablegen:InplaceOpInterface", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -65,6 +66,7 @@ cc_library( ":attributes_inc_gen", ":dialect_inc_gen", ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", "@llvm-project//mlir:IR", ], ) @@ -85,6 +87,7 @@ cc_library( ":dialect_inc_gen", ":ops_inc_gen", ":types_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", "@heir//lib/Utils/Tablegen:InplaceOpInterface", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -110,6 +113,7 @@ td_library( # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], deps = [ + "@heir//lib/Dialect:td_files", "@heir//lib/Utils/Tablegen:td_files", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:OpBaseTdFiles", diff --git a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td index 26956863ee..4989e5d741 100644 --- a/lib/Dialect/Lattigo/IR/LattigoBGVOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoBGVOps.td @@ -3,6 +3,7 @@ include "LattigoDialect.td" include "LattigoTypes.td" +include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/OpBase.td" class Lattigo_BGVOp traits = []> : @@ -105,8 +106,8 @@ def Lattigo_BGVNewEvaluatorOp : Lattigo_BGVOp<"new_evaluator"> { // ciphertext arithmetic op -class Lattigo_BGVBinaryOp : - Lattigo_BGVOp { +class Lattigo_BGVBinaryOp traits = []> : + Lattigo_BGVOp { let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, @@ -129,15 +130,15 @@ def Lattigo_BGVSubNewOp : Lattigo_BGVBinaryOp<"sub_new"> { }]; } -def Lattigo_BGVMulNewOp : Lattigo_BGVBinaryOp<"mul_new"> { +def Lattigo_BGVMulNewOp : Lattigo_BGVBinaryOp<"mul_new", [IncreasesMulDepthOpInterface]> { let summary = "Multiply two ciphertexts in the Lattigo BGV dialect"; let description = [{ This operation multiplies two ciphertext values in the Lattigo BGV dialect. }]; } -class Lattigo_BGVBinaryInplaceOp : - Lattigo_BGVOp { +class Lattigo_BGVBinaryInplaceOp traits = []> : + Lattigo_BGVOp { let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, @@ -171,7 +172,7 @@ def Lattigo_BGVSubOp : Lattigo_BGVBinaryInplaceOp<"sub"> { }]; } -def Lattigo_BGVMulOp : Lattigo_BGVBinaryInplaceOp<"mul"> { +def Lattigo_BGVMulOp : Lattigo_BGVBinaryInplaceOp<"mul", [IncreasesMulDepthOpInterface]> { let summary = "Multiply two ciphertexts in the Lattigo BGV dialect"; let description = [{ This operation multiplies two ciphertext values in the Lattigo BGV dialect. diff --git a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td index e12d2714cc..025279f3b1 100644 --- a/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoCKKSOps.td @@ -3,6 +3,7 @@ include "LattigoDialect.td" include "LattigoTypes.td" +include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/OpBase.td" class Lattigo_CKKSOp traits = []> : @@ -99,8 +100,8 @@ def Lattigo_CKKSNewEvaluatorOp : Lattigo_CKKSOp<"new_evaluator"> { // ciphertext arithmetic op -class Lattigo_CKKSBinaryOp : - Lattigo_CKKSOp { +class Lattigo_CKKSBinaryOp traits = []> : + Lattigo_CKKSOp { let arguments = (ins Lattigo_CKKSEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, @@ -123,15 +124,15 @@ def Lattigo_CKKSSubNewOp : Lattigo_CKKSBinaryOp<"sub_new"> { }]; } -def Lattigo_CKKSMulNewOp : Lattigo_CKKSBinaryOp<"mul_new"> { +def Lattigo_CKKSMulNewOp : Lattigo_CKKSBinaryOp<"mul_new", [IncreasesMulDepthOpInterface]> { let summary = "Multiply two ciphertexts in the Lattigo CKKS dialect"; let description = [{ This operation multiplies two ciphertext values in the Lattigo CKKS dialect. }]; } -class Lattigo_CKKSBinaryInplaceOp : - Lattigo_CKKSOp { +class Lattigo_CKKSBinaryInplaceOp traits = []> : + Lattigo_CKKSOp { let arguments = (ins Lattigo_CKKSEvaluator:$evaluator, Lattigo_RLWECiphertext:$lhs, @@ -165,7 +166,7 @@ def Lattigo_CKKSSubOp : Lattigo_CKKSBinaryInplaceOp<"sub"> { }]; } -def Lattigo_CKKSMulOp : Lattigo_CKKSBinaryInplaceOp<"mul"> { +def Lattigo_CKKSMulOp : Lattigo_CKKSBinaryInplaceOp<"mul", [IncreasesMulDepthOpInterface]> { let summary = "Multiply two ciphertexts in the Lattigo CKKS dialect"; let description = [{ This operation multiplies two ciphertext values in the Lattigo CKKS dialect. diff --git a/lib/Dialect/Lattigo/IR/LattigoOps.h b/lib/Dialect/Lattigo/IR/LattigoOps.h index b49c5be447..50f4ed6f44 100644 --- a/lib/Dialect/Lattigo/IR/LattigoOps.h +++ b/lib/Dialect/Lattigo/IR/LattigoOps.h @@ -1,10 +1,13 @@ #ifndef LIB_DIALECT_LATTIGO_IR_LATTIGOOPS_H_ #define LIB_DIALECT_LATTIGO_IR_LATTIGOOPS_H_ +// IWYU pragma: begin_keep +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Lattigo/IR/LattigoDialect.h" #include "lib/Dialect/Lattigo/IR/LattigoTypes.h" #include "lib/Utils/Tablegen/InplaceOpInterface.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +// IWYU pragma: end_keep #define GET_OP_CLASSES #include "lib/Dialect/Lattigo/IR/LattigoOps.h.inc" diff --git a/lib/Dialect/Lattigo/IR/LattigoRLWETypes.td b/lib/Dialect/Lattigo/IR/LattigoRLWETypes.td index 36b446df1c..886facd105 100644 --- a/lib/Dialect/Lattigo/IR/LattigoRLWETypes.td +++ b/lib/Dialect/Lattigo/IR/LattigoRLWETypes.td @@ -3,11 +3,12 @@ include "LattigoAttributes.td" +include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/DialectBase.td" include "mlir/IR/AttrTypeBase.td" -class Lattigo_RLWEType - : Lattigo_Type<"RLWE" # name, "rlwe." # typeMnemonic> { +class Lattigo_RLWEType traits = []> + : Lattigo_Type<"RLWE" # name, "rlwe." # typeMnemonic, traits> { } def Lattigo_RLWEKeyGenerator : Lattigo_RLWEType<"KeyGenerator", "key_generator"> { @@ -91,7 +92,7 @@ def Lattigo_RLWEPlaintext : Lattigo_RLWEType<"Plaintext", "plaintext"> { let asmName = "pt"; } -def Lattigo_RLWECiphertext : Lattigo_RLWEType<"Ciphertext", "ciphertext"> { +def Lattigo_RLWECiphertext : Lattigo_RLWEType<"Ciphertext", "ciphertext", [SecretTypeInterface]> { let description = [{ This type represents the ciphertext for the RLWE encryption scheme. }]; diff --git a/lib/Dialect/Lattigo/IR/LattigoTypes.h b/lib/Dialect/Lattigo/IR/LattigoTypes.h index 2959bd065e..9f2668ef3a 100644 --- a/lib/Dialect/Lattigo/IR/LattigoTypes.h +++ b/lib/Dialect/Lattigo/IR/LattigoTypes.h @@ -1,9 +1,12 @@ #ifndef LIB_DIALECT_LATTIGO_IR_LATTIGOTYPES_H_ #define LIB_DIALECT_LATTIGO_IR_LATTIGOTYPES_H_ +// IWYU pragma: begin_keep +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Lattigo/IR/LattigoAttributes.h" #include "lib/Dialect/Lattigo/IR/LattigoDialect.h" #include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project +// IWYU pragma: end_keep #define GET_TYPEDEF_CLASSES #include "lib/Dialect/Lattigo/IR/LattigoTypes.h.inc" diff --git a/lib/Dialect/Mgmt/IR/BUILD b/lib/Dialect/Mgmt/IR/BUILD index adfe5ff327..1081e9f18b 100644 --- a/lib/Dialect/Mgmt/IR/BUILD +++ b/lib/Dialect/Mgmt/IR/BUILD @@ -25,6 +25,7 @@ cc_library( "ops_inc_gen", ":MgmtAttributes", ":MgmtOps", + "@heir//lib/Dialect:HEIRInterfaces", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", @@ -66,6 +67,7 @@ cc_library( ":canonicalize_inc_gen", ":dialect_inc_gen", ":ops_inc_gen", + "@heir//lib/Dialect:HEIRInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", ], @@ -82,6 +84,7 @@ td_library( # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], deps = [ + "@heir//lib/Dialect:td_files", "@heir//lib/Utils/DRR", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", diff --git a/lib/Dialect/Mgmt/IR/MgmtOps.h b/lib/Dialect/Mgmt/IR/MgmtOps.h index cecf15aba7..315a353b2f 100644 --- a/lib/Dialect/Mgmt/IR/MgmtOps.h +++ b/lib/Dialect/Mgmt/IR/MgmtOps.h @@ -1,13 +1,13 @@ #ifndef LIB_DIALECT_MGMT_IR_MGMTOPS_H_ #define LIB_DIALECT_MGMT_IR_MGMTOPS_H_ -// NOLINTBEGIN(misc-include-cleaner): Required to define MgmtOps +// IWYU pragma: begin_keep +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Mgmt/IR/MgmtDialect.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -// NOLINTEND(misc-include-cleaner) - -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +// IWYU pragma: end_keep #define GET_OP_CLASSES #include "lib/Dialect/Mgmt/IR/MgmtOps.h.inc" diff --git a/lib/Dialect/Mgmt/IR/MgmtOps.td b/lib/Dialect/Mgmt/IR/MgmtOps.td index 51c3d2ae5b..d0fd8e4699 100644 --- a/lib/Dialect/Mgmt/IR/MgmtOps.td +++ b/lib/Dialect/Mgmt/IR/MgmtOps.td @@ -3,14 +3,15 @@ include "lib/Dialect/Mgmt/IR/MgmtDialect.td" +include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -class Mgmt_Op traits = [Pure, ElementwiseMappable, SameOperandsAndResultType]> : - Op { +class Mgmt_Op traits = []> : + Op { let cppNamespace = "::mlir::heir::mgmt"; } @@ -113,7 +114,7 @@ def Mgmt_BootstrapOp : Mgmt_Op<"bootstrap"> { let assemblyFormat = "operands attr-dict `:` type($output)"; } -def Mgmt_AdjustScaleOp : Mgmt_Op<"adjust_scale"> { +def Mgmt_AdjustScaleOp : Mgmt_Op<"adjust_scale", [IncreasesMulDepthOpInterface]> { let summary = "Adjust the scale of the input ciphertext (for BGV and CKKS)"; let description = [{ diff --git a/lib/Dialect/Openfhe/IR/BUILD b/lib/Dialect/Openfhe/IR/BUILD index 322e410606..a09d398785 100644 --- a/lib/Dialect/Openfhe/IR/BUILD +++ b/lib/Dialect/Openfhe/IR/BUILD @@ -22,7 +22,7 @@ cc_library( ":dialect_inc_gen", ":ops_inc_gen", ":types_inc_gen", - "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect:HEIRInterfaces", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", @@ -39,6 +39,7 @@ td_library( ], includes = ["../../../.."], deps = [ + "@heir//lib/Dialect:td_files", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", @@ -72,6 +73,5 @@ add_heir_dialect_library( td_file = "OpenfheOps.td", deps = [ ":td_files", - "@heir//lib/Dialect/LWE/IR:td_files", ], ) diff --git a/lib/Dialect/Openfhe/IR/OpenfheOps.cpp b/lib/Dialect/Openfhe/IR/OpenfheOps.cpp index bbf369b16e..10013ce084 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheOps.cpp +++ b/lib/Dialect/Openfhe/IR/OpenfheOps.cpp @@ -2,10 +2,7 @@ #include -#include "lib/Dialect/LWE/IR/LWEAttributes.h" -#include "lib/Dialect/LWE/IR/LWEOps.h" -#include "lib/Dialect/Openfhe/IR/OpenfheOps.h" -#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project +#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" #include "mlir/include/mlir/IR/Location.h" // from @llvm-project #include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project @@ -14,64 +11,6 @@ namespace mlir { namespace heir { -namespace openfhe { - -//===----------------------------------------------------------------------===// -// Op verifiers -//===----------------------------------------------------------------------===// - -LogicalResult MulNoRelinOp::verify() { return lwe::verifyMulOp(this); } - -LogicalResult MakePackedPlaintextOp::verify() { - auto enc = this->getPlaintext().getType().getPlaintextSpace().getEncoding(); - if (!llvm::isa(enc)) { - return emitOpError("plaintext type should use full_crt_packing_encoding."); - } - return success(); -} - -LogicalResult MakeCKKSPackedPlaintextOp::verify() { - auto enc = this->getPlaintext().getType().getPlaintextSpace().getEncoding(); - if (!llvm::isa(enc)) { - return emitOpError("plaintext type should use inverse_canonical_encoding."); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Op type inference. -//===----------------------------------------------------------------------===// - -LogicalResult AddOp::inferReturnTypes( - MLIRContext* ctx, std::optional, AddOp::Adaptor adaptor, - SmallVectorImpl& inferredReturnTypes) { - return lwe::inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes); -} - -LogicalResult AddPlainOp::inferReturnTypes( - MLIRContext* ctx, std::optional, AddPlainOp::Adaptor adaptor, - SmallVectorImpl& inferredReturnTypes) { - return lwe::inferPlainOpReturnTypes(ctx, adaptor, inferredReturnTypes); -} - -LogicalResult SubOp::inferReturnTypes( - MLIRContext* ctx, std::optional, SubOp::Adaptor adaptor, - SmallVectorImpl& inferredReturnTypes) { - return lwe::inferAddOpReturnTypes(ctx, adaptor, inferredReturnTypes); -} - -LogicalResult SubPlainOp::inferReturnTypes( - MLIRContext* ctx, std::optional, SubPlainOp::Adaptor adaptor, - SmallVectorImpl& inferredReturnTypes) { - return lwe::inferPlainOpReturnTypes(ctx, adaptor, inferredReturnTypes); -} - -LogicalResult MulNoRelinOp::inferReturnTypes( - MLIRContext* ctx, std::optional, MulNoRelinOp::Adaptor adaptor, - SmallVectorImpl& inferredReturnTypes) { - return lwe::inferMulOpReturnTypes(ctx, adaptor, inferredReturnTypes); -} - -} // namespace openfhe +namespace openfhe {} // namespace openfhe } // namespace heir } // namespace mlir diff --git a/lib/Dialect/Openfhe/IR/OpenfheOps.h b/lib/Dialect/Openfhe/IR/OpenfheOps.h index ab34042fa9..69e4bfa703 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheOps.h +++ b/lib/Dialect/Openfhe/IR/OpenfheOps.h @@ -1,14 +1,15 @@ #ifndef LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_H_ #define LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_H_ -#include "lib/Dialect/LWE/IR/LWETraits.h" -#include "lib/Dialect/LWE/IR/LWETypes.h" +// IWYU pragma: begin_keep +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" #include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +// IWYU pragma: end_keep #define GET_OP_CLASSES #include "lib/Dialect/Openfhe/IR/OpenfheOps.h.inc" diff --git a/lib/Dialect/Openfhe/IR/OpenfheOps.td b/lib/Dialect/Openfhe/IR/OpenfheOps.td index 09e87f54dc..14aa0b9caf 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheOps.td +++ b/lib/Dialect/Openfhe/IR/OpenfheOps.td @@ -4,8 +4,7 @@ include "OpenfheDialect.td" include "OpenfheTypes.td" -include "lib/Dialect/LWE/IR/LWETypes.td" -include "lib/Dialect/LWE/IR/LWETraits.td" +include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/OpBase.td" @@ -20,49 +19,27 @@ class Openfhe_Op traits = []> : let cppNamespace = "::mlir::heir::openfhe"; } -class Openfhe_UnaryTypeSwitchOp traits = []> - : Openfhe_Op{ +class Openfhe_UnaryOp traits = []> + : Openfhe_Op]>{ let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext + Openfhe_Ciphertext:$ciphertext ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } -class Openfhe_UnaryOp traits = []> - : Openfhe_UnaryTypeSwitchOp< - mnemonic, traits # [AllTypesMatch<["ciphertext", "output"]>] - >; - class Openfhe_BinaryOp traits = []> : Openfhe_Op ]>{ let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$lhs, - LWECiphertext:$rhs - ); - let results = (outs LWECiphertext:$output); -} - -class Openfhe_BinaryInPlaceOp traits = []> - : Openfhe_Op, - ]> { - - let summary = "In-place binary operation for OpenFHE"; - - let arguments = (ins - Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$lhs, - LWECiphertext:$rhs + Openfhe_Ciphertext:$lhs, + Openfhe_Ciphertext:$rhs ); + let results = (outs Openfhe_Ciphertext:$output); } - def GenParamsOp : Openfhe_Op<"gen_params"> { let description = [{ Generates the parameters for the OpenFHE scheme. @@ -148,112 +125,88 @@ def GenBootstrapKeyOp : Openfhe_Op<"gen_bootstrapkey"> { def MakePackedPlaintextOp : Openfhe_Op<"make_packed_plaintext", [Pure]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - RankedTensorOf<[AnyInteger]>:$value) - ; - let results = (outs LWEPlaintext:$plaintext); - let hasVerifier = 1; + RankedTensorOf<[AnyInteger]>:$value + ); + let results = (outs Openfhe_Plaintext:$plaintext); } def MakeCKKSPackedPlaintextOp : Openfhe_Op<"make_ckks_packed_plaintext", [Pure]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - RankedTensorOf<[AnyFloat, AnyInteger]>:$value) - ; - let results = (outs LWEPlaintext:$plaintext); - let hasVerifier = 1; + RankedTensorOf<[AnyFloat, AnyInteger]>:$value + ); + let results = (outs Openfhe_Plaintext:$plaintext); } def EncryptOp : Openfhe_Op<"encrypt", [Pure]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWEPlaintext:$plaintext, - Openfhe_PublicKeyOrPrivateKey:$encryptionKey) - ; - let results = (outs LWECiphertext:$ciphertext); + Openfhe_Plaintext:$plaintext, + Openfhe_PublicKeyOrPrivateKey:$encryptionKey + ); + let results = (outs Openfhe_Ciphertext:$ciphertext); } def DecryptOp : Openfhe_Op<"decrypt", [Pure]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext, - Openfhe_PrivateKey:$privateKey) - ; - let results = (outs LWEPlaintext:$plaintext); + Openfhe_Ciphertext:$ciphertext, + Openfhe_PrivateKey:$privateKey + ); + let results = (outs Openfhe_Plaintext:$plaintext); } -def AddOp : Openfhe_BinaryOp<"add", - [SameOperandsAndResultRings, - InferTypeOpAdaptor]> { +def AddOp : Openfhe_BinaryOp<"add"> { let summary = "OpenFHE add operation of two ciphertexts."; } -def SubOp : Openfhe_BinaryOp<"sub", - [SameOperandsAndResultRings, - InferTypeOpAdaptor]> { +def SubOp : Openfhe_BinaryOp<"sub"> { let summary = "OpenFHE sub operation of two ciphertexts."; } -// In-Place Addition -def AddInPlaceOp : Openfhe_BinaryInPlaceOp<"add_inplace"> { - let summary = "Performs in-place homomorphic addition, modifying lhs."; -} - -// In-Place Subtraction -def SubInPlaceOp : Openfhe_BinaryInPlaceOp<"sub_inplace"> { - let summary = "Performs in-place homomorphic subtraction, modifying lhs."; -} - +def Openfhe_PlaintextOrCiphertext : AnyTypeOf<[Openfhe_Ciphertext, Openfhe_Plaintext]>; -def AddPlainOp : Openfhe_Op<"add_plain",[ - Pure, - AllCiphertextTypesMatch, - InferTypeOpAdaptor -]> { +def AddPlainOp : Openfhe_Op<"add_plain", [Pure]> { let summary = "OpenFHE add operation of a ciphertext and a plaintext."; let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWEPlaintextOrCiphertext:$lhs, - LWEPlaintextOrCiphertext:$rhs + Openfhe_PlaintextOrCiphertext:$lhs, + Openfhe_PlaintextOrCiphertext:$rhs ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } -def SubPlainOp : Openfhe_Op<"sub_plain",[ - Pure, - AllCiphertextTypesMatch, - InferTypeOpAdaptor -]> { +def SubPlainOp : Openfhe_Op<"sub_plain", [Pure]> { let summary = "OpenFHE sub operation of a ciphertext and a plaintext."; let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWEPlaintextOrCiphertext:$lhs, - LWEPlaintextOrCiphertext:$rhs + Openfhe_PlaintextOrCiphertext:$lhs, + Openfhe_PlaintextOrCiphertext:$rhs ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } -def MulOp : Openfhe_BinaryOp<"mul"> { let summary = "OpenFHE mul operation of two ciphertexts with relinearization."; } +def MulOp : Openfhe_BinaryOp<"mul", [IncreasesMulDepthOpInterface]> { + let summary = "OpenFHE mul operation of two ciphertexts with relinearization."; +} -def MulNoRelinOp : Openfhe_Op<"mul_no_relin", [Pure, SameOperandsAndResultRings, InferTypeOpAdaptor]> { +def MulNoRelinOp : Openfhe_Op<"mul_no_relin", [Pure, IncreasesMulDepthOpInterface]> { let summary = "OpenFHE mul operation of two ciphertexts without relinearization."; let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$lhs, - LWECiphertext:$rhs + Openfhe_Ciphertext:$lhs, + Openfhe_Ciphertext:$rhs ); - let results = (outs LWECiphertext:$output); - let hasVerifier = 1; + let results = (outs Openfhe_Ciphertext:$output); } -def MulPlainOp : Openfhe_Op<"mul_plain",[ - Pure -]> { +def MulPlainOp : Openfhe_Op<"mul_plain", [Pure, IncreasesMulDepthOpInterface]> { let summary = "OpenFHE mul operation of a ciphertext and a plaintext."; let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext, - LWEPlaintext:$plaintext + Openfhe_Ciphertext:$ciphertext, + Openfhe_Plaintext:$plaintext ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } def MulConstOp : Openfhe_Op<"mul_const",[ @@ -263,22 +216,22 @@ def MulConstOp : Openfhe_Op<"mul_const",[ let summary = "OpenFHE mul operation of a ciphertext and a constant."; let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext, + Openfhe_Ciphertext:$ciphertext, I64:$constant ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } def NegateOp : Openfhe_UnaryOp<"negate"> { let summary = "OpenFHE negate operation of a ciphertext."; } def SquareOp : Openfhe_UnaryOp<"square"> { let summary = "OpenFHE square operation of a ciphertext."; } -def RelinOp : Openfhe_UnaryTypeSwitchOp<"relin"> { let summary = "OpenFHE relinearize operation of a ciphertext."; } +def RelinOp : Openfhe_UnaryOp<"relin"> { let summary = "OpenFHE relinearize operation of a ciphertext."; } -def ModReduceOp : Openfhe_UnaryTypeSwitchOp<"mod_reduce"> { let summary = "OpenFHE mod_reduce operation of a ciphertext. (used only for BGV/CKKS)"; } -def LevelReduceOp : Openfhe_UnaryTypeSwitchOp<"level_reduce"> { +def ModReduceOp : Openfhe_UnaryOp<"mod_reduce"> { let summary = "OpenFHE mod_reduce operation of a ciphertext. (used only for BGV/CKKS)"; } +def LevelReduceOp : Openfhe_UnaryOp<"level_reduce"> { let summary = "OpenFHE level_reduce operation of a ciphertext."; let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext, + Openfhe_Ciphertext:$ciphertext, DefaultValuedAttr:$levelToDrop ); } @@ -289,10 +242,10 @@ def RotOp : Openfhe_Op<"rot", [ ]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext, + Openfhe_Ciphertext:$ciphertext, Builtin_IntegerAttr:$index ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } def AutomorphOp : Openfhe_Op<"automorph", [ @@ -301,10 +254,10 @@ def AutomorphOp : Openfhe_Op<"automorph", [ ]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext, + Openfhe_Ciphertext:$ciphertext, Openfhe_EvalKey:$evalKey ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } def KeySwitchOp : Openfhe_Op<"key_switch", [ @@ -313,18 +266,18 @@ def KeySwitchOp : Openfhe_Op<"key_switch", [ ]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$ciphertext, + Openfhe_Ciphertext:$ciphertext, Openfhe_EvalKey:$evalKey ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); } -def BootstrapOp : Openfhe_UnaryTypeSwitchOp<"bootstrap"> { let summary = "OpenFHE bootstrap operation of a ciphertext. (For CKKS)"; } +def BootstrapOp : Openfhe_UnaryOp<"bootstrap"> { let summary = "OpenFHE bootstrap operation of a ciphertext. (For CKKS)"; } def FastRotationPrecomputeOp : Openfhe_Op<"fast_rotation_precompute", [Pure]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$input + Openfhe_Ciphertext:$input ); let results = (outs Openfhe_DigitDecomposition:$output); } @@ -335,12 +288,32 @@ def FastRotationOp : Openfhe_Op<"fast_rotation", [Pure, ]> { let arguments = (ins Openfhe_CryptoContext:$cryptoContext, - LWECiphertext:$input, + Openfhe_Ciphertext:$input, IndexAttr:$index, IndexAttr:$cyclotomicOrder, Openfhe_DigitDecomposition:$precomputedDigitDecomp ); - let results = (outs LWECiphertext:$output); + let results = (outs Openfhe_Ciphertext:$output); +} + +def DecodeOp : Openfhe_Op<"decode", [ + Pure, +]> { + let arguments = (ins Openfhe_Plaintext:$input); + let results = (outs AnyType:$output); + let assemblyFormat = [{ + operands attr-dict `:` type($input) `->` type($output) + }]; +} + +def DecodeCKKSOp : Openfhe_Op<"decode_ckks", [ + Pure, +]> { + let arguments = (ins Openfhe_Plaintext:$input); + let results = (outs AnyType:$output); + let assemblyFormat = [{ + operands attr-dict `:` type($input) `->` type($output) + }]; } #endif // LIB_DIALECT_OPENFHE_IR_OPENFHEOPS_TD_ diff --git a/lib/Dialect/Openfhe/IR/OpenfheTypes.h b/lib/Dialect/Openfhe/IR/OpenfheTypes.h index ff979c475b..06de7a6d6d 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheTypes.h +++ b/lib/Dialect/Openfhe/IR/OpenfheTypes.h @@ -1,8 +1,11 @@ #ifndef LIB_DIALECT_OPENFHE_IR_OPENFHETYPES_H_ #define LIB_DIALECT_OPENFHE_IR_OPENFHETYPES_H_ +// IWYU pragma: begin_keep +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" #include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project +// IWYU pragma: end_keep #define GET_TYPEDEF_CLASSES #include "lib/Dialect/Openfhe/IR/OpenfheTypes.h.inc" diff --git a/lib/Dialect/Openfhe/IR/OpenfheTypes.td b/lib/Dialect/Openfhe/IR/OpenfheTypes.td index 70f4115f7c..681e654976 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheTypes.td +++ b/lib/Dialect/Openfhe/IR/OpenfheTypes.td @@ -3,6 +3,7 @@ include "OpenfheDialect.td" +include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/CommonTypeConstraints.td" include "mlir/IR/DialectBase.td" @@ -28,6 +29,16 @@ class Openfhe_Type traits = []> }]; } +def Openfhe_Ciphertext : Openfhe_Type<"Ciphertext", "ciphertext", [SecretTypeInterface]> { + let summary = "An opaque OpenFHE ciphertext type"; + let asmName = "ct"; +} + +def Openfhe_Plaintext : Openfhe_Type<"Plaintext", "plaintext"> { + let summary = "An opaque OpenFHE plaintext type"; + let asmName = "pt"; +} + def Openfhe_PublicKey : Openfhe_Type<"PublicKey", "public_key"> { let summary = "The public key required to encrypt plaintext in OpenFHE."; let asmName = "pk"; diff --git a/lib/Dialect/Openfhe/Transforms/BUILD b/lib/Dialect/Openfhe/Transforms/BUILD index fd6d076c57..6e5ed11df3 100644 --- a/lib/Dialect/Openfhe/Transforms/BUILD +++ b/lib/Dialect/Openfhe/Transforms/BUILD @@ -28,16 +28,17 @@ cc_library( ], deps = [ ":pass_inc_gen", + "@heir//lib/Analysis/MulDepthAnalysis", + "@heir//lib/Analysis/SecretnessAnalysis", "@heir//lib/Dialect:ModuleAttributes", "@heir//lib/Dialect/BGV/IR:Dialect", "@heir//lib/Dialect/CKKS/IR:Dialect", - "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Mgmt/IR:Dialect", - "@heir//lib/Dialect/ModArith/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", - "@heir//lib/Dialect/RNS/IR:Dialect", + "@heir//lib/Utils", "@heir//lib/Utils:TransformUtils", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp b/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp index 970f7efe83..92c54ff187 100644 --- a/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp +++ b/lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.cpp @@ -4,33 +4,41 @@ #include #include +#include "lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h" +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" #include "lib/Dialect/BGV/IR/BGVAttributes.h" #include "lib/Dialect/BGV/IR/BGVDialect.h" #include "lib/Dialect/BGV/IR/BGVEnums.h" #include "lib/Dialect/CKKS/IR/CKKSAttributes.h" #include "lib/Dialect/CKKS/IR/CKKSDialect.h" -#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Dialect/CKKS/IR/CKKSEnums.h" #include "lib/Dialect/Mgmt/IR/MgmtAttributes.h" #include "lib/Dialect/Mgmt/IR/MgmtDialect.h" -#include "lib/Dialect/ModArith/IR/ModArithTypes.h" #include "lib/Dialect/ModuleAttributes.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" #include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" -#include "lib/Dialect/RNS/IR/RNSTypes.h" #include "lib/Utils/TransformUtils.h" -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project -#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project -#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "lib/Utils/Utils.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project + +#define DEBUG_TYPE "openfhe-configure-crypto-context" namespace mlir { namespace heir { @@ -233,8 +241,9 @@ struct ConfigureCryptoContext auto module = op->getParentOfType(); // remove bgv.schemeParam attribute if present - // fill encryptionTechniqueExtended + // fill encryptionTechniqueExtended and plaintextModulus config.encryptionTechniqueExtended = false; + config.plaintextModulus = 0; if (auto schemeParamAttr = module->getAttrOfType( bgv::BGVDialect::kSchemeParamAttrName)) { if (moduleIsBGV(module) && schemeParamAttr.getEncryptionTechnique() == @@ -246,10 +255,15 @@ struct ConfigureCryptoContext config.encryptionTechniqueExtended = schemeParamAttr.getEncryptionTechnique() == bgv::BGVEncryptionTechnique::extended; + if (!moduleIsCKKS(module)) + config.plaintextModulus = schemeParamAttr.getPlaintextModulus(); module->removeAttr(bgv::BGVDialect::kSchemeParamAttrName); } // remove ckks.schemeParam attribute if present + // For CKKS, plainMod must be 0 to avoid codegen for SetPlaintextModulus. + // OpenFHE will throw an exception if you try to set the plaintext modulus + // in CKKS. if (auto schemeParamAttr = module->getAttrOfType( ckks::CKKSDialect::kSchemeParamAttrName)) { if (schemeParamAttr.getEncryptionTechnique() == @@ -261,20 +275,34 @@ struct ConfigureCryptoContext module->removeAttr(ckks::CKKSDialect::kSchemeParamAttrName); } - /// Compute muldepth from multiply aspects... - // get mulDepth from function argument ciphertext type - for (auto arg : op.getArguments()) { - if (auto argType = dyn_cast( - getElementTypeOrSelf(arg.getType()))) { - if (auto rnsType = dyn_cast( - argType.getCiphertextSpace().getRing().getCoefficientType())) { - config.mulDepth = rnsType.getBasisTypes().size() - 1; - // implicitly assume arguments have the same level - break; - } - } + LLVM_DEBUG(llvm::dbgs() << "Recomputing mul depth\n"); + DataFlowSolver solver; + dataflow::loadBaselineAnalyses(solver); + solver.load(); + solver.load(); + + if (failed(solver.initializeAndRun(op))) { + op->emitOpError() << "Failed to run mul depth analysis.\n"; + return failure(); } + config.mulDepth = 0; + walkValues(op, [&](Value value) { + auto mulDepthState = + solver.lookupState(value)->getValue(); + if (!mulDepthState.isInitialized()) { + LLVM_DEBUG(llvm::dbgs() + << "mul depth uninitialized at " << value << "\n"); + return; + } + auto mulDepth = mulDepthState.getMulDepth(); + if (mulDepth > config.mulDepth) { + LLVM_DEBUG(llvm::dbgs() + << "Found larger mul depth=" << mulDepth << "\n"); + config.mulDepth = mulDepth; + } + }); + config.hasBootstrapOp = hasBootstrapOp(op); // TODO(#1207): determine mulDepth earlier in mgmt level // approxModDepth = 14, this solely depends on secretKeyDist @@ -293,25 +321,6 @@ struct ConfigureCryptoContext config.hasRelinOp = hasRelinOp(op); config.rotIndices = findAllRotIndices(op); - // Get plaintext modulus from function argument ciphertext type for CKKS, - // plainMod must be 0 to avoid codegen for SetPlaintextModulus. OpenFHE will - // throw an exception if you try to set the plaintext modulus in CKKS. - if (moduleIsCKKS(module)) { - config.plaintextModulus = 0; - } else { - for (auto arg : op.getArguments()) { - if (auto argType = dyn_cast( - getElementTypeOrSelf(arg.getType()))) { - if (auto modArithType = dyn_cast( - argType.getPlaintextSpace().getRing().getCoefficientType())) { - config.plaintextModulus = modArithType.getModulus().getInt(); - // implicitly assume arguments have the same plaintext modulus - break; - } - } - } - } - // get evalAddCount/KeySwitchCount from func attribute, if present config.evalAddCount = 0; config.keySwitchCount = 0; diff --git a/lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.cpp b/lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.cpp index cbe941e317..82c7f5603f 100644 --- a/lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.cpp +++ b/lib/Dialect/Openfhe/Transforms/FastRotationPrecompute.cpp @@ -2,7 +2,6 @@ #include -#include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" #include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" #include "lib/Utils/ConversionUtils.h" @@ -66,13 +65,7 @@ void processFunc(func::FuncOp funcOp, Value cryptoContext) { // dimension used by OpenFHE. However, OpenFHE sets its own parameters, // and so this ends up being ignored in favor of dynamically reading // `cc->GetRingDimension() * 2`. - int cyclotomicOrder = - 2 * cast(ciphertext.getType()) - .getCiphertextSpace() - .getRing() - .getPolynomialModulus() - .getPolynomial() - .getDegree(); + int cyclotomicOrder = 0; auto fastRot = FastRotationOp::create( builder, op->getLoc(), op.getType(), op.getCryptoContext(), op.getCiphertext(), op.getIndex(), diff --git a/lib/Dialect/Secret/IR/SecretDialect.cpp b/lib/Dialect/Secret/IR/SecretDialect.cpp index 0397615885..db48c0774b 100644 --- a/lib/Dialect/Secret/IR/SecretDialect.cpp +++ b/lib/Dialect/Secret/IR/SecretDialect.cpp @@ -2,7 +2,6 @@ // IWYU pragma: begin_keep #include "lib/Dialect/Secret/IR/SecretAttributes.h" -#include "lib/Dialect/Secret/IR/SecretDialect.cpp.inc" #include "lib/Dialect/Secret/IR/SecretOps.h" #include "lib/Dialect/Secret/IR/SecretTypes.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project @@ -10,6 +9,7 @@ #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project // IWYU pragma: end_keep +#include "lib/Dialect/Secret/IR/SecretDialect.cpp.inc" #define GET_ATTRDEF_CLASSES #include "lib/Dialect/Secret/IR/SecretAttributes.cpp.inc" #define GET_TYPEDEF_CLASSES diff --git a/lib/Dialect/Secret/IR/SecretTypes.h b/lib/Dialect/Secret/IR/SecretTypes.h index 3cc1536212..587b195392 100644 --- a/lib/Dialect/Secret/IR/SecretTypes.h +++ b/lib/Dialect/Secret/IR/SecretTypes.h @@ -1,7 +1,10 @@ #ifndef LIB_DIALECT_SECRET_IR_SECRETTYPES_H_ #define LIB_DIALECT_SECRET_IR_SECRETTYPES_H_ +// IWYU pragma: begin_keep +#include "lib/Dialect/HEIRInterfaces.h" #include "lib/Dialect/Secret/IR/SecretDialect.h" +// IWYU pragma: end_keep #define GET_TYPEDEF_CLASSES #include "lib/Dialect/Secret/IR/SecretTypes.h.inc" diff --git a/lib/Dialect/Secret/IR/SecretTypes.td b/lib/Dialect/Secret/IR/SecretTypes.td index ac2b5bbb4f..386a3feb74 100644 --- a/lib/Dialect/Secret/IR/SecretTypes.td +++ b/lib/Dialect/Secret/IR/SecretTypes.td @@ -3,16 +3,17 @@ include "SecretDialect.td" -include "mlir/IR/DialectBase.td" +include "lib/Dialect/HEIRInterfaces.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" // A base class for all types in this dialect -class Secret_Type - : TypeDef { +class Secret_Type traits = []> + : TypeDef { let mnemonic = typeMnemonic; } -def Secret : Secret_Type<"Secret", "secret"> { +def Secret : Secret_Type<"Secret", "secret", [SecretTypeInterface]> { let summary = "A secret value"; let description = [{ diff --git a/lib/Target/OpenFhePke/BUILD b/lib/Target/OpenFhePke/BUILD index 94f7afb093..1b6b8bb1ed 100644 --- a/lib/Target/OpenFhePke/BUILD +++ b/lib/Target/OpenFhePke/BUILD @@ -26,7 +26,6 @@ cc_library( "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Dialect/Polynomial/IR:Dialect", "@heir//lib/Dialect/RNS/IR:Dialect", - "@heir//lib/Dialect/RNS/IR:RNSTypeInterfaces", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", @@ -49,7 +48,6 @@ cc_library( ], deps = [ "@heir//lib/Analysis/SelectVariableNames", - "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Utils:TargetUtils", "@llvm-project//llvm:Support", @@ -70,7 +68,6 @@ cc_library( "@cereal", "@heir//lib/Analysis/SelectVariableNames", "@heir//lib/Dialect:ModuleAttributes", - "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", "@heir//lib/Utils:TargetUtils", "@llvm-project//llvm:Support", @@ -128,10 +125,9 @@ cc_library( "//conditions:default": [], }), deps = [ - "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/ModArith/IR:Dialect", "@heir//lib/Dialect/Openfhe/IR:Dialect", - "@heir//lib/Dialect/Polynomial/IR:Dialect", "@heir//lib/Dialect/RNS/IR:Dialect", "@heir//lib/Dialect/RNS/IR:RNSTypeInterfaces", # buildcleaner: keep "@heir//lib/Dialect/TensorExt/IR:Dialect", diff --git a/lib/Target/OpenFhePke/Interpreter.cpp b/lib/Target/OpenFhePke/Interpreter.cpp index 624a02e0f0..5b5f41d8b8 100644 --- a/lib/Target/OpenFhePke/Interpreter.cpp +++ b/lib/Target/OpenFhePke/Interpreter.cpp @@ -10,13 +10,11 @@ #include #include -#include "lib/Dialect/LWE/IR/LWEAttributes.h" -#include "lib/Dialect/LWE/IR/LWEDialect.h" -#include "lib/Dialect/LWE/IR/LWEOps.h" -#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Dialect/Mgmt/IR/MgmtDialect.h" #include "lib/Dialect/ModArith/IR/ModArithDialect.h" #include "lib/Dialect/Openfhe/IR/OpenfheDialect.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" +#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" #include "lib/Dialect/RNS/IR/RNSDialect.h" #include "lib/Dialect/RNS/IR/RNSTypeInterfaces.h" #include "lib/Dialect/TensorExt/IR/TensorExtDialect.h" @@ -178,7 +176,8 @@ void Interpreter::initializeDispatchTable() { REGISTER_OP(scf::YieldOp); REGISTER_OP(affine::AffineForOp); REGISTER_OP(affine::AffineYieldOp); - REGISTER_OP(lwe::RLWEDecodeOp); + REGISTER_OP(DecodeOp); + REGISTER_OP(DecodeCKKSOp); REGISTER_OP(AddOp); REGISTER_OP(AddPlainOp); REGISTER_OP(SubOp); @@ -216,8 +215,8 @@ void Interpreter::initializeDispatchTable() { void Interpreter::eraseValue(Value v) { llvm::TypeSwitch(v.getType()) - .Case([&](auto ty) { plaintexts.erase(v); }) - .Case([&](auto ty) { ciphertexts.erase(v); }) + .Case([&](auto ty) { plaintexts.erase(v); }) + .Case([&](auto ty) { ciphertexts.erase(v); }) .Case([&](auto ty) { auto elemType = ty.getElementType(); if (elemType.isInteger() || elemType.isIndex()) { @@ -226,9 +225,9 @@ void Interpreter::eraseValue(Value v) { floatVectors.erase(v); } else if (elemType.isF64()) { doubleVectors.erase(v); - } else if (isa(elemType)) { + } else if (isa(elemType)) { plaintextVectors.erase(v); - } else if (isa(elemType)) { + } else if (isa(elemType)) { ciphertextVectors.erase(v); } else { llvm::errs() << "Unsupported tensor element type " << elemType @@ -315,11 +314,11 @@ TypedCppValue Interpreter::loadTypedValue(Value v) { TypedCppValue result; llvm::TypeSwitch(v.getType()) - .Case([&](auto ty) { + .Case([&](auto ty) { if (auto it = plaintexts.find(v); it != plaintexts.end()) result = TypedCppValue(it->second); }) - .Case([&](auto ty) { + .Case([&](auto ty) { if (auto it = ciphertexts.find(v); it != ciphertexts.end()) result = TypedCppValue(it->second); }) @@ -334,10 +333,10 @@ TypedCppValue Interpreter::loadTypedValue(Value v) { } else if (elemType.isF64()) { if (auto it = doubleVectors.find(v); it != doubleVectors.end()) result = TypedCppValue(it->second); - } else if (isa(elemType)) { + } else if (isa(elemType)) { if (auto it = plaintextVectors.find(v); it != plaintextVectors.end()) result = TypedCppValue(it->second); - } else if (isa(elemType)) { + } else if (isa(elemType)) { if (auto it = ciphertextVectors.find(v); it != ciphertextVectors.end()) result = TypedCppValue(it->second); @@ -799,7 +798,7 @@ void Interpreter::visit(tensor::EmptyOp op) { } else if (elementType.isF64()) { doubleVectors[op.getResult()] = std::make_shared>(numElements); - } else if (isa(elementType)) { + } else if (isa(elementType)) { plaintextVectors[op.getResult()] = std::make_shared>(numElements); } else { @@ -819,7 +818,7 @@ void Interpreter::visit(tensor::ExtractOp op) { floatValues[op.getResult()] = (*floatVectors.at(op.getTensor()))[index]; } else if (elemType.isF64()) { doubleValues[op.getResult()] = (*doubleVectors.at(op.getTensor()))[index]; - } else if (isa(elemType)) { + } else if (isa(elemType)) { plaintexts[op.getResult()] = (*plaintextVectors.at(op.getTensor()))[index]; } else { ciphertexts[op.getResult()] = @@ -854,7 +853,7 @@ void Interpreter::visit(tensor::InsertOp op) { : std::make_shared>(*srcVec); (*vec)[index] = doubleValues.at(op.getScalar()); doubleVectors[op.getResult()] = vec; - } else if (isa(elemType)) { + } else if (isa(elemType)) { auto srcVec = plaintextVectors.at(op.getDest()); auto vec = canModifyInPlace ? srcVec @@ -919,7 +918,7 @@ void Interpreter::visit(tensor::FromElementsOp op) { (*result)[i] = doubleValues.at(elements[i]); } doubleVectors[op.getResult()] = result; - } else if (isa(elemType)) { + } else if (isa(elemType)) { auto result = std::make_shared>(elements.size()); for (size_t i = 0; i < elements.size(); ++i) { (*result)[i] = plaintexts.at(elements[i]); @@ -976,7 +975,7 @@ void Interpreter::visit(tensor::ConcatOp op) { result->insert(result->end(), vec.begin(), vec.end()); } doubleVectors[op.getResult()] = result; - } else if (isa(elemType)) { + } else if (isa(elemType)) { size_t totalSize = 0; for (auto input : inputs) { totalSize += plaintextVectors.at(input)->size(); @@ -1014,7 +1013,7 @@ void Interpreter::visit(tensor::CollapseShapeOp op) { floatVectors[op.getResult()] = floatVectors.at(op.getSrc()); } else if (elemType.isF64()) { doubleVectors[op.getResult()] = doubleVectors.at(op.getSrc()); - } else if (isa(elemType)) { + } else if (isa(elemType)) { plaintextVectors[op.getResult()] = plaintextVectors.at(op.getSrc()); } else { ciphertextVectors[op.getResult()] = ciphertextVectors.at(op.getSrc()); @@ -1032,7 +1031,7 @@ void Interpreter::visit(tensor::ExpandShapeOp op) { floatVectors[op.getResult()] = floatVectors.at(op.getSrc()); } else if (elemType.isF64()) { doubleVectors[op.getResult()] = doubleVectors.at(op.getSrc()); - } else if (isa(elemType)) { + } else if (isa(elemType)) { plaintextVectors[op.getResult()] = plaintextVectors.at(op.getSrc()); } else { ciphertextVectors[op.getResult()] = ciphertextVectors.at(op.getSrc()); @@ -1105,7 +1104,7 @@ void Interpreter::visit(tensor::ExtractSliceOp op) { } doubleVectors[op.getResult()] = std::make_shared>(std::move(result)); - } else if (isa(elemType)) { + } else if (isa(elemType)) { auto result = std::vector(totalElements); const auto& srcVec = *plaintextVectors.at(op.getSource()); for (int64_t i = 0; i < totalElements; ++i) { @@ -1113,7 +1112,7 @@ void Interpreter::visit(tensor::ExtractSliceOp op) { } plaintextVectors[op.getResult()] = std::make_shared>(std::move(result)); - } else if (isa(elemType)) { + } else if (isa(elemType)) { auto result = std::vector(totalElements); const auto& srcVec = *ciphertextVectors.at(op.getSource()); for (int64_t i = 0; i < totalElements; ++i) { @@ -1197,7 +1196,7 @@ void Interpreter::visit(tensor::InsertSliceOp op) { (*destVec)[insertElement(i)] = srcVec[i]; } doubleVectors[op.getResult()] = std::move(destVec); - } else if (isa(elemType)) { + } else if (isa(elemType)) { auto srcDestVec = plaintextVectors.at(op.getDest()); auto destVec = canModifyInPlace ? srcDestVec @@ -1207,7 +1206,7 @@ void Interpreter::visit(tensor::InsertSliceOp op) { (*destVec)[insertElement(i)] = srcVec[i]; } plaintextVectors[op.getResult()] = std::move(destVec); - } else if (isa(elemType)) { + } else if (isa(elemType)) { auto srcDestVec = ciphertextVectors.at(op.getDest()); auto destVec = canModifyInPlace @@ -1536,13 +1535,13 @@ void Interpreter::visit(AddPlainOp op) { auto lhsVal = op.getLhs(); auto rhsVal = op.getRhs(); - if (isa(lhsVal.getType()) && - isa(rhsVal.getType())) { + if (isa(lhsVal.getType()) && + isa(rhsVal.getType())) { auto lhsCt = ciphertexts.at(lhsVal); auto rhsPt = plaintexts.at(rhsVal); TIME_OPERATION("AddPlain", op.getOutput(), cc->EvalAdd(lhsCt, rhsPt)); - } else if (isa(lhsVal.getType()) && - isa(rhsVal.getType())) { + } else if (isa(lhsVal.getType()) && + isa(rhsVal.getType())) { auto lhsPt = plaintexts.at(lhsVal); auto rhsCt = ciphertexts.at(rhsVal); TIME_OPERATION("AddPlain", op.getOutput(), cc->EvalAdd(lhsPt, rhsCt)); @@ -1557,13 +1556,13 @@ void Interpreter::visit(SubPlainOp op) { auto rhsVal = op.getRhs(); // Check which is ciphertext and which is plaintext - if (isa(lhsVal.getType()) && - isa(rhsVal.getType())) { + if (isa(lhsVal.getType()) && + isa(rhsVal.getType())) { auto lhsCt = ciphertexts.at(lhsVal); auto rhsPt = plaintexts.at(rhsVal); TIME_OPERATION("SubPlain", op.getOutput(), cc->EvalSub(lhsCt, rhsPt)); - } else if (isa(lhsVal.getType()) && - isa(rhsVal.getType())) { + } else if (isa(lhsVal.getType()) && + isa(rhsVal.getType())) { auto lhsPt = plaintexts.at(lhsVal); auto rhsCt = ciphertexts.at(rhsVal); TIME_OPERATION("SubPlain", op.getOutput(), cc->EvalSub(lhsPt, rhsCt)); @@ -1801,11 +1800,12 @@ void Interpreter::visit(FastRotationPrecomputeOp op) { cc->EvalFastRotationPrecompute(ct), fastRotPrecomps); } -void Interpreter::visit(lwe::RLWEDecodeOp op) { - auto plaintext = plaintexts.at(op.getInput()); - bool isCKKS = llvm::isa(op.getEncoding()); +void Interpreter::decodeCore(Operation* op, Value input, Value result, + bool isCKKS) { + auto plaintext = plaintexts.at(input); - if (auto tensorTy = dyn_cast(op.getResult().getType())) { + // Tensor case + if (auto tensorTy = dyn_cast(result.getType())) { auto shape = tensorTy.getShape(); auto nonUnitDims = llvm::count_if(shape, [](auto dim) { return dim != 1; }); if (nonUnitDims != 1) { @@ -1828,45 +1828,56 @@ void Interpreter::visit(lwe::RLWEDecodeOp op) { auto elemType = tensorTy.getElementType(); if (elemType.isF64()) { - auto result = std::make_shared>(); - result->reserve(ckksValues.size()); + auto res = std::make_shared>(); + res->reserve(ckksValues.size()); for (const auto& val : ckksValues) { - result->push_back(val.real()); + res->push_back(val.real()); } - doubleVectors[op.getResult()] = result; + doubleVectors[result] = res; } else { - auto result = std::make_shared>(); - result->reserve(ckksValues.size()); + auto res = std::make_shared>(); + res->reserve(ckksValues.size()); for (const auto& val : ckksValues) { - result->push_back(static_cast(val.real())); + res->push_back(static_cast(val.real())); } - floatVectors[op.getResult()] = result; + floatVectors[result] = res; } - } else { - auto packedValues = plaintext->GetPackedValue(); - auto result = std::make_shared>(); - result->reserve(packedValues.size()); - for (const auto& val : packedValues) { - result->push_back(static_cast(val)); - } - intVectors[op.getResult()] = result; + return; } - } else { - // Scalar result - if (isCKKS) { - auto ckksValues = plaintext->GetCKKSPackedValue(); - auto elemType = op.getResult().getType(); - if (elemType.isF64()) { - doubleValues[op.getResult()] = ckksValues[0].real(); - } else { - floatValues[op.getResult()] = static_cast(ckksValues[0].real()); - } + auto packedValues = plaintext->GetPackedValue(); + auto res = std::make_shared>(); + res->reserve(packedValues.size()); + for (const auto& val : packedValues) { + res->push_back(static_cast(val)); + } + intVectors[result] = res; + return; + } + + // Scalar result + if (isCKKS) { + auto ckksValues = plaintext->GetCKKSPackedValue(); + auto elemType = result.getType(); + + if (elemType.isF64()) { + doubleValues[result] = ckksValues[0].real(); } else { - auto packedValues = plaintext->GetPackedValue(); - intValues[op.getResult()] = static_cast(packedValues[0]); + floatValues[result] = static_cast(ckksValues[0].real()); } + return; } + + auto packedValues = plaintext->GetPackedValue(); + intValues[result] = static_cast(packedValues[0]); +} + +void Interpreter::visit(DecodeOp op) { + decodeCore(op, op.getInput(), op.getResult(), false); +} + +void Interpreter::visit(DecodeCKKSOp op) { + decodeCore(op, op.getInput(), op.getResult(), true); } void initContext(MLIRContext& context) { @@ -1874,9 +1885,8 @@ void initContext(MLIRContext& context) { registry.insert(); registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); - registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/lib/Target/OpenFhePke/Interpreter.h b/lib/Target/OpenFhePke/Interpreter.h index 71b7181786..3948d47b90 100644 --- a/lib/Target/OpenFhePke/Interpreter.h +++ b/lib/Target/OpenFhePke/Interpreter.h @@ -8,7 +8,6 @@ #include #include -#include "lib/Dialect/LWE/IR/LWEOps.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project #include "mlir/include/mlir/Analysis/Liveness.h" // from @llvm-project @@ -160,6 +159,8 @@ class Interpreter { void visit(AddPlainOp op); void visit(AutomorphOp op); void visit(BootstrapOp op); + void visit(DecodeCKKSOp op); + void visit(DecodeOp op); void visit(DecryptOp op); void visit(EncryptOp op); void visit(FastRotationOp op); @@ -186,9 +187,6 @@ class Interpreter { void visit(SubOp op); void visit(SubPlainOp op); - // Other HEIR ops - void visit(lwe::RLWEDecodeOp op); - int getFlattenedTensorIndex(Value tensor, ValueRange indices); private: @@ -198,6 +196,9 @@ class Interpreter { // Helper to convert TypedCppValue to type-specific storage (for inputs) void storeTypedValue(Value v, const TypedCppValue& typedVal); + // Helper for decoding + void decodeCore(Operation* op, Value input, Value result, bool isCKKS); + // Helper to convert from type-specific storage to TypedCppValue (for outputs) TypedCppValue loadTypedValue(Value v); ModuleOp module; diff --git a/lib/Target/OpenFhePke/InterpreterTest.cpp b/lib/Target/OpenFhePke/InterpreterTest.cpp index 4b04a7fcda..6573105c43 100644 --- a/lib/Target/OpenFhePke/InterpreterTest.cpp +++ b/lib/Target/OpenFhePke/InterpreterTest.cpp @@ -503,18 +503,8 @@ struct CryptoSetup { // Common LWE type definitions header for MLIR tests static const char* kLWETypesHeader = R"mlir( -!Z65537 = !mod_arith.int<65537 : i64> -!Z1095233372161 = !mod_arith.int<1095233372161 : i64> -!rns_L0 = !rns.rns -#ring_pt = #polynomial.ring> -#ring_ct = #polynomial.ring> -#encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain = #lwe.modulus_chain, current = 0> -#plaintext_space = #lwe.plaintext_space -#ciphertext_space = #lwe.ciphertext_space -!ct = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space, key = #key, modulus_chain = #modulus_chain> -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> +!ct = !openfhe.ciphertext +!pt = !openfhe.plaintext )mlir"; TEST(InterpreterTest, TestOpenfheAdd) { @@ -874,7 +864,7 @@ TEST(InterpreterTest, TestOpenfheRLWEDecodeBGVScalar) { std::string mlirStr = std::string(kLWETypesHeader) + R"mlir( module attributes {scheme.bgv} { func.func @main(%pt: !pt) -> i32 { - %result = lwe.rlwe_decode %pt {encoding = #encoding, ring = #ring_pt} : !pt -> i32 + %result = openfhe.decode %pt : !pt -> i32 return %result : i32 } } @@ -900,7 +890,7 @@ TEST(InterpreterTest, TestOpenfheRLWEDecodeBGVTensor) { std::string mlirStr = std::string(kLWETypesHeader) + R"mlir( module attributes {scheme.bgv} { func.func @main(%pt: !pt) -> tensor<8xi32> { - %result = lwe.rlwe_decode %pt {encoding = #encoding, ring = #ring_pt} : !pt -> tensor<8xi32> + %result = openfhe.decode %pt : !pt -> tensor<8xi32> return %result : tensor<8xi32> } } @@ -950,19 +940,11 @@ TEST(InterpreterTest, TestOpenfheRLWEDecodeCKKSScalar) { MLIRContext context; initContext(context); std::string mlirStr = R"mlir( -!Z65537 = !mod_arith.int<65537 : i64> -!Z1095233372161 = !mod_arith.int<1095233372161 : i64> -!rns_L0 = !rns.rns -#ring_pt = #polynomial.ring> -#ckks_encoding = #lwe.inverse_canonical_encoding -#key = #lwe.key<> -#modulus_chain = #lwe.modulus_chain, current = 0> -#plaintext_space = #lwe.plaintext_space -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> +!pt = !openfhe.plaintext module attributes {scheme.ckks} { func.func @main(%pt: !pt) -> f32 { - %result = lwe.rlwe_decode %pt {encoding = #ckks_encoding, ring = #ring_pt} : !pt -> f32 + %result = openfhe.decode_ckks %pt : !pt -> f32 return %result : f32 } } @@ -987,19 +969,11 @@ TEST(InterpreterTest, TestOpenfheRLWEDecodeCKKSTensor) { MLIRContext context; initContext(context); std::string mlirStr = R"mlir( -!Z65537 = !mod_arith.int<65537 : i64> -!Z1095233372161 = !mod_arith.int<1095233372161 : i64> -!rns_L0 = !rns.rns -#ring_pt = #polynomial.ring> -#ckks_encoding = #lwe.inverse_canonical_encoding -#key = #lwe.key<> -#modulus_chain = #lwe.modulus_chain, current = 0> -#plaintext_space = #lwe.plaintext_space -!pt = !lwe.lwe_plaintext>, plaintext_space = #plaintext_space> +!pt = !openfhe.plaintext module attributes {scheme.ckks} { func.func @main(%pt: !pt) -> tensor<8xf32> { - %result = lwe.rlwe_decode %pt {encoding = #ckks_encoding, ring = #ring_pt} : !pt -> tensor<8xf32> + %result = openfhe.decode_ckks %pt : !pt -> tensor<8xf32> return %result : tensor<8xf32> } } @@ -1186,19 +1160,11 @@ TEST(InterpreterTest, TestOpenfheRLWEDecodeCKKSScalarDouble) { MLIRContext context; initContext(context); std::string mlirStr = R"mlir( -!Z65537 = !mod_arith.int<65537 : i64> -!Z1095233372161 = !mod_arith.int<1095233372161 : i64> -!rns_L0 = !rns.rns -#ring_pt = #polynomial.ring> -#ckks_encoding = #lwe.inverse_canonical_encoding -#key = #lwe.key<> -#modulus_chain = #lwe.modulus_chain, current = 0> -#plaintext_space = #lwe.plaintext_space -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> +!pt = !openfhe.plaintext module attributes {scheme.ckks} { func.func @main(%pt: !pt) -> f64 { - %result = lwe.rlwe_decode %pt {encoding = #ckks_encoding, ring = #ring_pt} : !pt -> f64 + %result = openfhe.decode_ckks %pt : !pt -> f64 return %result : f64 } } @@ -1223,19 +1189,11 @@ TEST(InterpreterTest, TestOpenfheRLWEDecodeCKKSTensorDouble) { MLIRContext context; initContext(context); std::string mlirStr = R"mlir( -!Z65537 = !mod_arith.int<65537 : i64> -!Z1095233372161 = !mod_arith.int<1095233372161 : i64> -!rns_L0 = !rns.rns -#ring_pt = #polynomial.ring> -#ckks_encoding = #lwe.inverse_canonical_encoding -#key = #lwe.key<> -#modulus_chain = #lwe.modulus_chain, current = 0> -#plaintext_space = #lwe.plaintext_space -!pt = !lwe.lwe_plaintext>, plaintext_space = #plaintext_space> +!pt = !openfhe.plaintext module attributes {scheme.ckks} { func.func @main(%pt: !pt) -> tensor<8xf64> { - %result = lwe.rlwe_decode %pt {encoding = #ckks_encoding, ring = #ring_pt} : !pt -> tensor<8xf64> + %result = openfhe.decode_ckks %pt : !pt -> tensor<8xf64> return %result : tensor<8xf64> } } diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 76334dba63..930458d761 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -16,11 +16,9 @@ #include "include/cereal/archives/portable_binary.hpp" // from @cereal #include "include/cereal/cereal.hpp" // from @cereal #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" -#include "lib/Dialect/LWE/IR/LWEAttributes.h" -#include "lib/Dialect/LWE/IR/LWEOps.h" -#include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Dialect/ModuleAttributes.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" +#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" #include "lib/Target/OpenFhePke/OpenFheUtils.h" #include "lib/Utils/TargetUtils.h" #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project @@ -250,8 +248,7 @@ LogicalResult OpenFhePkeEmitter::translate(Operation& op) { tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::FromElementsOp, tensor::ConcatOp>( [&](auto op) { return printOperation(op); }) - // LWE ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) // OpenFHE ops .Case(op.getResult().getType())) { + if (isa(op.getResult().getType())) { emitAutoAssignPrefix(op.getResult()); } else { if (failed(emitTypedAssignPrefix(op.getResult(), op.getLoc(), true))) @@ -1130,8 +1127,7 @@ LogicalResult OpenFhePkeEmitter::printOperation(tensor::ExtractOp op) { auto constantStr = getStringForConstant(value); return constantStr.value_or(variableNames->getNameForValue(value)); }); - os << "]"; - os << ";\n"; + os << "];\n"; return success(); } @@ -1347,13 +1343,6 @@ LogicalResult OpenFhePkeEmitter::printOperation(tensor::FromElementsOp op) { return success(); } -LogicalResult OpenFhePkeEmitter::printOperation( - lwe::ReinterpretApplicationDataOp op) { - emitAutoAssignPrefix(op.getResult()); - os << variableNames->getNameForValue(op.getInput()) << ";\n"; - return success(); -} - LogicalResult OpenFhePkeEmitter::printOperation( openfhe::MakePackedPlaintextOp op) { std::string inputVarName = variableNames->getNameForValue(op.getValue()); @@ -1465,27 +1454,26 @@ FailureOr> getNonUnitDimension( return std::make_pair(nonUnitIndex, shape[nonUnitIndex]); } -LogicalResult OpenFhePkeEmitter::printOperation(lwe::RLWEDecodeOp op) { +LogicalResult OpenFhePkeEmitter::decodeCore(Location loc, Value input, + Value result, bool isCKKS) { // In OpenFHE a plaintext is already decoded by decrypt. The internal // OpenFHE implementation is simple enough (and dependent on // currently-hard-coded encoding choices) that we will eventually need to // work at a lower level of the API to support this operation properly. - bool isCKKS = llvm::isa(op.getEncoding()); - auto tensorTy = dyn_cast(op.getResult().getType()); + auto tensorTy = dyn_cast(result.getType()); if (tensorTy) { auto nonUnitDim = getNonUnitDimension(tensorTy); if (failed(nonUnitDim)) { - return emitError(op.getLoc(), "Only 1D tensors supported"); + return emitError(loc, "Only 1D tensors supported"); } // OpenFHE plaintexts must be manually resized to the decoded output size // via plaintext->SetLength(); auto size = nonUnitDim.value().second; - auto inputVarName = variableNames->getNameForValue(op.getInput()); + auto inputVarName = variableNames->getNameForValue(input); os << inputVarName << "->SetLength(" << size << ");\n"; // Get the packed values in OpenFHE's type (vector of int_64t/complex/etc) - std::string tmpVar = - variableNames->getNameForValue(op.getResult()) + "_cast"; + std::string tmpVar = variableNames->getNameForValue(result) + "_cast"; os << "const auto& " << tmpVar << " = "; if (isCKKS) { os << inputVarName << "->GetCKKSPackedValue();\n"; @@ -1494,8 +1482,8 @@ LogicalResult OpenFhePkeEmitter::printOperation(lwe::RLWEDecodeOp op) { } // Convert to the intended type defined by the program - auto outputVarName = variableNames->getNameForValue(op.getResult()); - if (failed(emitType(tensorTy, op->getLoc()))) { + auto outputVarName = variableNames->getNameForValue(result); + if (failed(emitType(tensorTy, loc))) { return failure(); } if (isCKKS) { @@ -1514,9 +1502,9 @@ LogicalResult OpenFhePkeEmitter::printOperation(lwe::RLWEDecodeOp op) { } // By convention, a plaintext stores a scalar value in index 0 - auto result = emitTypedAssignPrefix(op.getResult(), op->getLoc()); - if (failed(result)) return result; - os << variableNames->getNameForValue(op.getInput()); + auto res = emitTypedAssignPrefix(result, loc); + if (failed(res)) return res; + os << variableNames->getNameForValue(input); if (isCKKS) { os << "->GetCKKSPackedValue()[0].real();\n"; } else { @@ -1525,6 +1513,16 @@ LogicalResult OpenFhePkeEmitter::printOperation(lwe::RLWEDecodeOp op) { return success(); } +LogicalResult OpenFhePkeEmitter::printOperation(DecodeOp op) { + return decodeCore(op.getLoc(), op.getInput(), op.getResult(), + /*isCKKS=*/false); +} + +LogicalResult OpenFhePkeEmitter::printOperation(DecodeCKKSOp op) { + return decodeCore(op.getLoc(), op.getInput(), op.getResult(), + /*isCKKS=*/true); +} + LogicalResult OpenFhePkeEmitter::printOperation(EncryptOp op) { return printEvalMethod(op.getResult(), op.getCryptoContext(), {op.getEncryptionKey(), op.getPlaintext()}, "Encrypt"); diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h index f0353af0a8..516da5d543 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h @@ -15,7 +15,6 @@ // IWYU pragma: end_keep #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" -#include "lib/Dialect/LWE/IR/LWEOps.h" #include "lib/Dialect/Openfhe/IR/OpenfheOps.h" #include "lib/Target/OpenFhePke/OpenFheUtils.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project @@ -126,14 +125,13 @@ class OpenFhePkeEmitter { LogicalResult printOperation(::mlir::func::FuncOp op); LogicalResult printOperation(::mlir::func::CallOp op); LogicalResult printOperation(::mlir::func::ReturnOp op); - LogicalResult printOperation(::mlir::heir::lwe::RLWEDecodeOp op); - LogicalResult printOperation( - ::mlir::heir::lwe::ReinterpretApplicationDataOp op); LogicalResult printOperation(AddOp op); LogicalResult printOperation(AddPlainOp op); LogicalResult printOperation(AutomorphOp op); LogicalResult printOperation(BootstrapOp op); LogicalResult printOperation(DecryptOp op); + LogicalResult printOperation(DecodeOp op); + LogicalResult printOperation(DecodeCKKSOp op); LogicalResult printOperation(EncryptOp op); LogicalResult printOperation(GenParamsOp op); LogicalResult printOperation(FastRotationOp op); @@ -166,6 +164,8 @@ class OpenFhePkeEmitter { std::string_view op); LogicalResult printBinaryOp(Operation* op, ::mlir::Value lhs, ::mlir::Value rhs, std::string_view opName); + LogicalResult decodeCore(::mlir::Location loc, ::mlir::Value input, + ::mlir::Value result, bool isCKKS); // A helper for a special case of ExtractSliceOp LogicalResult extractRowFromMatrix(tensor::ExtractSliceOp op); diff --git a/lib/Target/OpenFhePke/OpenFheUtils.cpp b/lib/Target/OpenFhePke/OpenFheUtils.cpp index a8f2d8ac36..395c144abb 100644 --- a/lib/Target/OpenFhePke/OpenFheUtils.cpp +++ b/lib/Target/OpenFhePke/OpenFheUtils.cpp @@ -3,7 +3,6 @@ #include #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" -#include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" #include "lib/Target/OpenFhePke/OpenFhePkeTemplates.h" #include "lib/Utils/TargetUtils.h" @@ -50,12 +49,11 @@ FailureOr convertType(Type type, Location loc, bool constant) { .Case( [&](auto ty) { return std::string("CryptoContextT"); }) .Case([&](auto ty) { return std::string("CCParamsT"); }) - .Case([&](auto ty) { + .Case([&](auto ty) { return constant ? std::string("CiphertextT") : std::string("MutableCiphertextT"); }) - .Case( - [&](auto ty) { return std::string("Plaintext"); }) + .Case([&](auto ty) { return std::string("Plaintext"); }) .Case( [&](auto ty) { return std::string("EvalKeyT"); }) .Case( diff --git a/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.cpp b/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.cpp new file mode 100644 index 0000000000..356384a411 --- /dev/null +++ b/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.cpp @@ -0,0 +1,61 @@ +#include "lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h" + +#include "lib/Analysis/MulDepthAnalysis/MulDepthAnalysis.h" +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "lib/Utils/AttributeUtils.h" +#include "lib/Utils/Utils.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/Utils.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +#define DEBUG_TYPE "annotate-muldepth" + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_ANNOTATEMULDEPTH +#include "lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h.inc" + +struct AnnotateMulDepth : impl::AnnotateMulDepthBase { + using AnnotateMulDepthBase::AnnotateMulDepthBase; + + void runOnOperation() override { + DataFlowSolver solver; + dataflow::loadBaselineAnalyses(solver); + solver.load(); + solver.load(); + + auto result = solver.initializeAndRun(getOperation()); + + if (failed(result)) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } + + walkValues(getOperation(), [&](Value value) { + auto* lattice = solver.lookupState(value); + LLVM_DEBUG(llvm::dbgs() << "lattice for value " << value << " : "); + if (!lattice) { + LLVM_DEBUG(llvm::dbgs() << "mul depth lattice undefined\n"); + return; + } + auto& state = lattice->getValue(); + if (!state.isInitialized()) { + LLVM_DEBUG(llvm::dbgs() << "mul depth lattice uninitialized\n"); + return; + } + OpBuilder b(value.getContext()); + setAttributeAssociatedWith(value, "secret.mul_depth", + b.getIndexAttr(state.getMulDepth())); + }); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h b/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h new file mode 100644 index 0000000000..1e1fcea3f8 --- /dev/null +++ b/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h @@ -0,0 +1,18 @@ +#ifndef LIB_TRANSFORMS_ANNOTATEMULDEPTH_ANNOTATEMULDEPTH_H_ +#define LIB_TRANSFORMS_ANNOTATEMULDEPTH_ANNOTATEMULDEPTH_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // LIB_TRANSFORMS_ANNOTATEMULDEPTH_ANNOTATEMULDEPTH_H_ diff --git a/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.td b/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.td new file mode 100644 index 0000000000..4ba51fe478 --- /dev/null +++ b/lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.td @@ -0,0 +1,16 @@ +#ifndef LIB_TRANSFORMS_ANNOTATEMULDEPTH_ANNOTATEMULDEPTH_TD_ +#define LIB_TRANSFORMS_ANNOTATEMULDEPTH_ANNOTATEMULDEPTH_TD_ + +include "mlir/Pass/PassBase.td" + +def AnnotateMulDepth : Pass<"annotate-muldepth"> { + let summary = "Annotate multiplicative depth in the IR"; + let description = [{ + Debugging helper that runs the multiplicative depth analysis and annotates + the IR with the results. + + (* example filepath=tests/Transforms/annotate_muldepth/doctest.mlir *) + }]; +} + +#endif // LIB_TRANSFORMS_ANNOTATEMULDEPTH_ANNOTATEMULDEPTH_TD_ diff --git a/lib/Transforms/AnnotateMulDepth/BUILD b/lib/Transforms/AnnotateMulDepth/BUILD new file mode 100644 index 0000000000..60c47552a2 --- /dev/null +++ b/lib/Transforms/AnnotateMulDepth/BUILD @@ -0,0 +1,32 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") +load("@rules_cc//cc:cc_library.bzl", "cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "AnnotateMulDepth", + srcs = ["AnnotateMulDepth.cpp"], + hdrs = [ + "AnnotateMulDepth.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Analysis/MulDepthAnalysis", + "@heir//lib/Analysis/SecretnessAnalysis", + "@heir//lib/Utils", + "@heir//lib/Utils:AttributeUtils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + +add_heir_transforms( + generated_target_name = "pass_inc_gen", + pass_name = "AnnotateMulDepth", +) diff --git a/lib/Transforms/SecretInsertMgmt/Pipeline.cpp b/lib/Transforms/SecretInsertMgmt/Pipeline.cpp index 99374c2387..d72a649189 100644 --- a/lib/Transforms/SecretInsertMgmt/Pipeline.cpp +++ b/lib/Transforms/SecretInsertMgmt/Pipeline.cpp @@ -99,7 +99,12 @@ void insertModReduceBeforeOrAfterMult(Operation* top, DataFlowSolver& solver, bool beforeMulIncludeFirstMul, bool includeFloats) { MLIRContext* ctx = top->getContext(); - LLVM_DEBUG(llvm::dbgs() << "Insert ModReduce Before/After Mult\n"); + LLVM_DEBUG({ + auto when = "before mul"; + if (afterMul) when = "after mul"; + if (beforeMulIncludeFirstMul) when = "before mul + before first mul"; + llvm::dbgs() << "Insert ModReduce " << when << "\n"; + }); RewritePatternSet patterns(ctx); if (afterMul) { diff --git a/lib/Utils/ConversionUtils.cpp b/lib/Utils/ConversionUtils.cpp index c107a06c16..cb3050124a 100644 --- a/lib/Utils/ConversionUtils.cpp +++ b/lib/Utils/ConversionUtils.cpp @@ -270,6 +270,12 @@ void addStructuralConversionPatterns(TypeConverter& typeConverter, scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, target); + target.addDynamicallyLegalOp( + [&](Operation* op) { return typeConverter.isLegal(op); }); + patterns + .add, ConvertAny>( + typeConverter, patterns.getContext()); + target.markUnknownOpDynamicallyLegal([&](Operation* op) { // These rules are needed to handle interface ops that are not directly // registered as legal/illegal with the target. diff --git a/tests/Dialect/BGV/Conversions/bgv_to_openfhe/bgv_to_openfhe.mlir b/tests/Dialect/BGV/Conversions/bgv_to_openfhe/bgv_to_openfhe.mlir index dbf6795d22..0879399bf8 100644 --- a/tests/Dialect/BGV/Conversions/bgv_to_openfhe/bgv_to_openfhe.mlir +++ b/tests/Dialect/BGV/Conversions/bgv_to_openfhe/bgv_to_openfhe.mlir @@ -33,7 +33,7 @@ // CHECK: module module { // CHECK: @test_ops - // CHECK-SAME: ([[C:%.+]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*lwe_ciphertext.*]], [[Y:%.+]]: [[T]], [[Z:%.+]]: [[P:.*lwe_plaintext.[^)]*]]) + // CHECK-SAME: ([[C:%.+]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:!openfhe.ciphertext]], [[Y:%.+]]: [[T]], [[Z:%.+]]: [[P:!openfhe.plaintext]]) func.func @test_ops(%x : !ct, %y : !ct, %z : !pt) -> (!ct, !ct, !ct, !ct_D3, !ct, !ct, !ct, !ct) { // CHECK: %[[v1:.*]] = openfhe.negate [[C]], %[[x1:.*]] : ([[S]], [[T]]) -> [[T]] %negate = bgv.negate %x : !ct @@ -56,7 +56,7 @@ module { } // CHECK: @test_relin - // CHECK-SAME: ([[C:.*]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*lwe_ciphertext.*]]) + // CHECK-SAME: ([[C:.*]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:!openfhe.ciphertext]]) func.func @test_relin(%x : !ct_D4) -> !ct { // CHECK: %[[v6:.*]] = openfhe.relin [[C]], %[[x6:.*]]: ([[S]], [[T]]) -> [[T2:.*]] %relin = bgv.relinearize %x { @@ -66,9 +66,9 @@ module { } // CHECK: @test_modswitch - // CHECK-SAME: ([[C:.*]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*lwe_ciphertext.*]]) -> [[T1:.*]] { + // CHECK-SAME: ([[C:.*]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:!openfhe.ciphertext]]) -> [[T]] func.func @test_modswitch(%x : !ct) -> !ct_L0 { - // CHECK: %[[v7:.*]] = openfhe.mod_reduce [[C]], %[[x7:.*]] : ([[S]], [[T]]) -> [[T1]] + // CHECK: %[[v7:.*]] = openfhe.mod_reduce [[C]], %[[x7:.*]] : ([[S]], [[T]]) -> [[T]] %mod_switch = bgv.modulus_switch %x { to_ring=#ring_rns_L0_1_x1024_ }: !ct -> !ct_L0 return %mod_switch : !ct_L0 } diff --git a/tests/Dialect/BGV/Conversions/bgv_to_openfhe/cast.mlir b/tests/Dialect/BGV/Conversions/bgv_to_openfhe/cast.mlir index 55277c57bd..d002775bf1 100644 --- a/tests/Dialect/BGV/Conversions/bgv_to_openfhe/cast.mlir +++ b/tests/Dialect/BGV/Conversions/bgv_to_openfhe/cast.mlir @@ -22,12 +22,12 @@ //The function is adapted from the BGV form of the simple_sum.mlir test // CHECK: @encode_i16 // CHECK-SAME: %[[cc:.*]]: !openfhe.crypto_context -// CHECK-SAME: %[[arg16:.*]]: tensor<32xi16> +// CHECK-SAME: %[[arg0:.*]]: tensor<32xi16> func.func @encode_i16(%arg0: tensor<32xi16>, %arg1: !pk) -> !pt_i16 { %0 = lwe.rlwe_encode %arg0 {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32_} : tensor<32xi16> -> !pt_i16 - // CHECK: %[[v0:.*]] = arith.extsi %[[arg16]] : tensor<32xi16> to tensor<32xi64> - // CHECK: openfhe.make_packed_plaintext %[[cc]], %[[v0]] {{.*}} tensor<32xi64>) -> !lwe.lwe_plaintext{{.*}} tensor<32xi16> - // CHECK-NOT: openfhe.make_packed_plaintext {{.*}} tensor<32xi16> -> !lwe.lwe_plaintext{{.*}} tensor<32xi16> + // CHECK: %[[v0:.*]] = arith.extsi %[[arg0]] : tensor<32xi16> to tensor<32xi64> + // CHECK: openfhe.make_packed_plaintext %[[cc]], %[[v0]] : (!openfhe.crypto_context, tensor<32xi64>) -> !openfhe.plaintext + // CHECK-NOT: openfhe.make_packed_plaintext {{.*}} tensor<32xi16> return %0 : !pt_i16 } @@ -37,14 +37,16 @@ func.func @encode_i16(%arg0: tensor<32xi16>, %arg1: !pk) -> !pt_i16 { func.func @encode_i32(%arg0: tensor<32xi32>, %arg1: !pk) -> !pt_i32 { %0 = lwe.rlwe_encode %arg0 {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32_} : tensor<32xi32> -> !pt_i32 // CHECK: %[[v0:.*]] = arith.extsi %[[arg0]] : tensor<32xi32> to tensor<32xi64> - // CHECK: openfhe.make_packed_plaintext %[[cc]], %[[v0]] {{.*}} tensor<32xi64>) -> !lwe.lwe_plaintext{{.*}} tensor<32xi32> - // CHECK-NOT: openfhe.make_packed_plaintext {{.*}} : tensor<32xi32> -> !lwe.lwe_plaintext{{.*}} tensor<32xi32> + // CHECK: openfhe.make_packed_plaintext %[[cc]], %[[v0]] : (!openfhe.crypto_context, tensor<32xi64>) -> !openfhe.plaintext + // CHECK-NOT: openfhe.make_packed_plaintext {{.*}} : tensor<32xi32> return %0 : !pt_i32 } // CHECK: @encode_i64 +// CHECK-SAME: %[[cc:.*]]: !openfhe.crypto_context +// CHECK-SAME: %[[arg0:.*]]: tensor<32xi64> func.func @encode_i64(%arg0: tensor<32xi64>, %arg1: !pk) -> !pt_i64 { %0 = lwe.rlwe_encode %arg0 {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32_} : tensor<32xi64> -> !pt_i64 - // CHECK: openfhe.make_packed_plaintext {{.*}} tensor<32xi64>) -> !lwe.lwe_plaintext{{.*}} tensor<32xi64> + // CHECK: openfhe.make_packed_plaintext %[[cc]], %[[arg0]] : (!openfhe.crypto_context, tensor<32xi64>) -> !openfhe.plaintext return %0 : !pt_i64 } diff --git a/tests/Dialect/BGV/Conversions/bgv_to_openfhe/linear_polynomial.mlir b/tests/Dialect/BGV/Conversions/bgv_to_openfhe/linear_polynomial.mlir index 3f5e1cfc06..139e20e363 100644 --- a/tests/Dialect/BGV/Conversions/bgv_to_openfhe/linear_polynomial.mlir +++ b/tests/Dialect/BGV/Conversions/bgv_to_openfhe/linear_polynomial.mlir @@ -24,7 +24,7 @@ !ct_sq_ty = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_D3_, key = #key, modulus_chain = #modulus_chain_L5_C0_> // CHECK: @linear_polynomial -// CHECK-SAME: (%[[cc:.*]]: [[cc_ty:.*crypto_context]], %[[arg0:.*]]: [[T:.*lwe_ciphertext.*]], %[[arg1:.*]]: [[T]], %[[arg2:.*]]: [[T]], %[[arg3:.*]]: [[T]]) -> [[T]] { +// CHECK-SAME: (%[[cc:.*]]: [[cc_ty:!openfhe.crypto_context]], %[[arg0:.*]]: [[T:!openfhe.ciphertext]], %[[arg1:.*]]: [[T]], %[[arg2:.*]]: [[T]], %[[arg3:.*]]: [[T]]) -> [[T]] func.func @linear_polynomial(%arg0: !ct_ty, %arg1: !ct_ty, %arg2: !ct_ty, %arg3: !ct_ty) -> !ct_ty { // CHECK: %[[v0:.*]] = openfhe.mul_no_relin %[[cc]], %[[arg0]], %[[arg2]] %0 = bgv.mul %arg0, %arg2 : (!ct_ty, !ct_ty) -> !ct_sq_ty diff --git a/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/cast.mlir b/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/cast.mlir index d59caa1021..698892c075 100644 --- a/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/cast.mlir +++ b/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/cast.mlir @@ -26,8 +26,8 @@ func.func @encode_i16(%arg0: tensor<32xi16>, %arg1: !pk) -> !pt_i16 { %0 = lwe.rlwe_encode %arg0 {encoding = #inverse_canonical_encoding, ring = #ring_Z65537_i64_1_x32_} : tensor<32xi16> -> !pt_i16 // CHECK: %[[v0:.*]] = arith.extsi %[[arg16]] : tensor<32xi16> to tensor<32xi64> - // CHECK: openfhe.make_ckks_packed_plaintext %[[cc]], %[[v0]] {{.*}} tensor<32xi64>) -> !lwe.lwe_plaintext{{.*}} tensor<32xi16> - // CHECK-NOT: openfhe.make_ckks_packed_plaintext {{.*}} tensor<32xi16> -> !lwe.lwe_plaintext{{.*}} tensor<32xi16> + // CHECK: openfhe.make_ckks_packed_plaintext %[[cc]], %[[v0]] {{.*}} tensor<32xi64>) -> !openfhe.plaintext + // CHECK-NOT: openfhe.make_ckks_packed_plaintext {{.*}} tensor<32xi16> -> !openfhe.plaintext return %0 : !pt_i16 } @@ -37,14 +37,14 @@ func.func @encode_i16(%arg0: tensor<32xi16>, %arg1: !pk) -> !pt_i16 { func.func @encode_i32(%arg0: tensor<32xi32>, %arg1: !pk) -> !pt_i32 { %0 = lwe.rlwe_encode %arg0 {encoding = #inverse_canonical_encoding, ring = #ring_Z65537_i64_1_x32_} : tensor<32xi32> -> !pt_i32 // CHECK: %[[v0:.*]] = arith.extsi %[[arg0]] : tensor<32xi32> to tensor<32xi64> - // CHECK: openfhe.make_ckks_packed_plaintext %[[cc]], %[[v0]] {{.*}} tensor<32xi64>) -> !lwe.lwe_plaintext{{.*}} tensor<32xi32> - // CHECK-NOT: openfhe.make_ckks_packed_plaintext {{.*}} : tensor<32xi32> -> !lwe.lwe_plaintext{{.*}} tensor<32xi32> + // CHECK: openfhe.make_ckks_packed_plaintext %[[cc]], %[[v0]] {{.*}} tensor<32xi64>) -> !openfhe.plaintext + // CHECK-NOT: openfhe.make_ckks_packed_plaintext {{.*}} : tensor<32xi32> -> !openfhe.plaintext return %0 : !pt_i32 } // CHECK: @encode_i64 func.func @encode_i64(%arg0: tensor<32xi64>, %arg1: !pk) -> !pt_i64 { %0 = lwe.rlwe_encode %arg0 {encoding = #inverse_canonical_encoding, ring = #ring_Z65537_i64_1_x32_} : tensor<32xi64> -> !pt_i64 - // CHECK: openfhe.make_ckks_packed_plaintext {{.*}} tensor<32xi64>) -> !lwe.lwe_plaintext{{.*}} tensor<32xi64> + // CHECK: openfhe.make_ckks_packed_plaintext {{.*}} tensor<32xi64>) -> !openfhe.plaintext return %0 : !pt_i64 } diff --git a/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/ckks_to_openfhe.mlir b/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/ckks_to_openfhe.mlir index c1642c5e19..f5aacb51b5 100644 --- a/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/ckks_to_openfhe.mlir +++ b/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/ckks_to_openfhe.mlir @@ -34,7 +34,7 @@ // CHECK: module module { // CHECK: @test_ops - // CHECK-SAME: ([[C:%.+]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*lwe_ciphertext.*]], [[Y:%.+]]: [[T]], [[Z:%.+]]: [[P:.*lwe_plaintext[^)]*]]) + // CHECK-SAME: ([[C:%.+]]: [[S:!openfhe.crypto_context]], [[X:%.+]]: [[T:!openfhe.ciphertext]], [[Y:%.+]]: [[T]], [[Z:%.+]]: [[P:!openfhe.plaintext]]) func.func @test_ops(%x : !ct, %y : !ct, %z : !pt) -> (!ct, !ct, !ct, !ct_D3, !ct, !ct, !ct, !ct) { // CHECK: %[[v1:.*]] = openfhe.negate [[C]], %[[x1:.*]] : ([[S]], [[T]]) -> [[T]] %negate = ckks.negate %x : !ct @@ -57,7 +57,7 @@ module { } // CHECK: @test_relin - // CHECK-SAME: ([[C:.*]]: [[S:.*crypto_context]], [[X:%.+]]: [[T:.*lwe_ciphertext.*]]) + // CHECK-SAME: ([[C:.*]]: [[S:!openfhe.crypto_context]], [[X:%.+]]: [[T:!openfhe.ciphertext]]) func.func @test_relin(%x : !ct_D4) -> !ct { // CHECK: %[[v6:.*]] = openfhe.relin [[C]], %[[x6:.*]]: ([[S]], [[T]]) -> [[T2:.*]] %relin = ckks.relinearize %x { diff --git a/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/linear_polynomial.mlir b/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/linear_polynomial.mlir index 0320cf15c5..8dc72c8709 100644 --- a/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/linear_polynomial.mlir +++ b/tests/Dialect/CKKS/Conversions/ckks_to_openfhe/linear_polynomial.mlir @@ -24,7 +24,7 @@ !ct_sq_ty = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_D3_, key = #key, modulus_chain = #modulus_chain_L5_C0_> // CHECK: @linear_polynomial -// CHECK-SAME: (%[[cc:.*]]: [[cc_ty:.*crypto_context]], %[[arg0:.*]]: [[T:.*lwe_ciphertext.*]], %[[arg1:.*]]: [[T]], %[[arg2:.*]]: [[T]], %[[arg3:.*]]: [[T]]) -> [[T]] { +// CHECK-SAME: (%[[cc:.*]]: [[cc_ty:!openfhe.crypto_context]], %[[arg0:.*]]: [[T:!openfhe.ciphertext]], %[[arg1:.*]]: [[T]], %[[arg2:.*]]: [[T]], %[[arg3:.*]]: [[T]]) -> [[T]] func.func @linear_polynomial(%arg0: !ct_ty, %arg1: !ct_ty, %arg2: !ct_ty, %arg3: !ct_ty) -> !ct_ty { // CHECK: %[[v0:.*]] = openfhe.mul_no_relin %[[cc]], %[[arg0]], %[[arg2]] %0 = ckks.mul %arg0, %arg2 : (!ct_ty, !ct_ty) -> !ct_sq_ty diff --git a/tests/Dialect/LWE/Conversions/lwe_to_openfhe/affine.mlir b/tests/Dialect/LWE/Conversions/lwe_to_openfhe/affine.mlir new file mode 100644 index 0000000000..9f20f2e770 --- /dev/null +++ b/tests/Dialect/LWE/Conversions/lwe_to_openfhe/affine.mlir @@ -0,0 +1,118 @@ +// RUN: heir-opt --lwe-to-openfhe %s | FileCheck %s + +// CHECK-NOT: lwe + +!Z35184269590529_i64 = !mod_arith.int<35184269590529 : i64> +!Z35184270114817_i64 = !mod_arith.int<35184270114817 : i64> +!Z35184272736257_i64 = !mod_arith.int<35184272736257 : i64> +!Z35184275619841_i64 = !mod_arith.int<35184275619841 : i64> +!Z35184279552001_i64 = !mod_arith.int<35184279552001 : i64> +!Z35184281387009_i64 = !mod_arith.int<35184281387009 : i64> +!Z35184284270593_i64 = !mod_arith.int<35184284270593 : i64> +!Z35184290562049_i64 = !mod_arith.int<35184290562049 : i64> +!Z35184297639937_i64 = !mod_arith.int<35184297639937 : i64> +!Z35184301047809_i64 = !mod_arith.int<35184301047809 : i64> +!Z35184306290689_i64 = !mod_arith.int<35184306290689 : i64> +!Z35184307077121_i64 = !mod_arith.int<35184307077121 : i64> +!Z35184314941441_i64 = !mod_arith.int<35184314941441 : i64> +!Z35184316776449_i64 = !mod_arith.int<35184316776449 : i64> +!Z35184318087169_i64 = !mod_arith.int<35184318087169 : i64> +!Z35184320708609_i64 = !mod_arith.int<35184320708609 : i64> +!Z35184329097217_i64 = !mod_arith.int<35184329097217 : i64> +!Z35184330145793_i64 = !mod_arith.int<35184330145793 : i64> +!Z35184339320833_i64 = !mod_arith.int<35184339320833 : i64> +!Z35184345088001_i64 = !mod_arith.int<35184345088001 : i64> +!Z35184350330881_i64 = !mod_arith.int<35184350330881 : i64> +!Z35184365273089_i64 = !mod_arith.int<35184365273089 : i64> +!Z35184376545281_i64 = !mod_arith.int<35184376545281 : i64> +!Z35184377331713_i64 = !mod_arith.int<35184377331713 : i64> +!Z35184385196033_i64 = !mod_arith.int<35184385196033 : i64> +!Z35184399351809_i64 = !mod_arith.int<35184399351809 : i64> +!Z35184404070401_i64 = !mod_arith.int<35184404070401 : i64> +!Z35184410361857_i64 = !mod_arith.int<35184410361857 : i64> +!Z35184414031873_i64 = !mod_arith.int<35184414031873 : i64> +!Z35184415080449_i64 = !mod_arith.int<35184415080449 : i64> +!Z35184415866881_i64 = !mod_arith.int<35184415866881 : i64> +!Z35184423731201_i64 = !mod_arith.int<35184423731201 : i64> +!Z35184430022657_i64 = !mod_arith.int<35184430022657 : i64> +!Z35184430809089_i64 = !mod_arith.int<35184430809089 : i64> +!Z35184436314113_i64 = !mod_arith.int<35184436314113 : i64> +!Z35184440246273_i64 = !mod_arith.int<35184440246273 : i64> +!Z35184440770561_i64 = !mod_arith.int<35184440770561 : i64> +!Z35184446537729_i64 = !mod_arith.int<35184446537729 : i64> +!Z35184452567041_i64 = !mod_arith.int<35184452567041 : i64> +!Z35184454402049_i64 = !mod_arith.int<35184454402049 : i64> +!Z35184454926337_i64 = !mod_arith.int<35184454926337 : i64> +!Z35184463839233_i64 = !mod_arith.int<35184463839233 : i64> +!Z35184465412097_i64 = !mod_arith.int<35184465412097 : i64> +!Z35184474587137_i64 = !mod_arith.int<35184474587137 : i64> +!Z35184478519297_i64 = !mod_arith.int<35184478519297 : i64> +!Z36028797005856769_i64 = !mod_arith.int<36028797005856769 : i64> +#inverse_canonical_encoding = #lwe.inverse_canonical_encoding +#key = #lwe.key<> +#layout = #tensor_ext.layout<"{ [i0, i1] -> [ct, slot] : i0 = 0 and ct = 0 and (-i1 + slot) mod 16 = 0 and 0 <= i1 <= 9 and 0 <= slot <= 1023 }"> +#modulus_chain_L45_C45 = #lwe.modulus_chain, current = 45> +#ring_f64_1_x1024 = #polynomial.ring> +!rns_L45 = !rns.rns +#original_type = #tensor_ext.original_type, layout = #layout> +!pt = !lwe.lwe_plaintext>, plaintext_space = > +#ring_rns_L45_1_x1024 = #polynomial.ring> +#ciphertext_space_L45 = #lwe.ciphertext_space +!ct_L45 = !lwe.lwe_ciphertext>, plaintext_space = , ciphertext_space = #ciphertext_space_L45, key = #key, modulus_chain = #modulus_chain_L45_C45> +module attributes {backend.openfhe, ckks.schemeParam = #ckks.scheme_param, scheme.ckks} { + func.func @lenet(%arg0: tensor<1x!ct_L45> {tensor_ext.original_type = #tensor_ext.original_type, layout = #tensor_ext.layout<"{ [i0, i1, i2, i3] -> [ct, slot] : i0 = 0 and i1 = 0 and ct = 0 and (-32i2 - i3 + slot) mod 1024 = 0 and 0 <= i2 <= 31 and 0 <= i3 <= 31 and 0 <= slot <= 1023 }">>}, %arg1: tensor<1x!ct_L45>, %arg2: tensor<1x!ct_L45>, %arg3: tensor<1x!ct_L45>, %arg4: tensor<1x!ct_L45>, %arg5: tensor<1x!ct_L45>, %arg6: tensor<1x1024xf32>, %ct: !ct_L45, %pt: !pt, %arg7: tensor<1x!ct_L45>) -> (tensor<1x!ct_L45> {tensor_ext.original_type = #original_type}) { + %cst = arith.constant dense<0.000000e+00> : tensor<1x1024xf32> + %cst_0 = arith.constant dense<5.000000e+00> : tensor<6x1024xf32> + %c0 = arith.constant 0 : index + %cst_1 = arith.constant 0.235571414 : f32 + %0 = tensor.empty() : tensor<6x!ct_L45> + %c1_i32 = arith.constant 1 : i32 + %c784_i32 = arith.constant 784 : i32 + %c6_i32 = arith.constant 6 : i32 + %c0_i32 = arith.constant 0 : i32 + %inserted_slice = tensor.insert_slice %arg1 into %0[0] [1] [1] : tensor<1x!ct_L45> into tensor<6x!ct_L45> + %inserted_slice_2 = tensor.insert_slice %arg2 into %inserted_slice[1] [1] [1] : tensor<1x!ct_L45> into tensor<6x!ct_L45> + %inserted_slice_3 = tensor.insert_slice %arg3 into %inserted_slice_2[2] [1] [1] : tensor<1x!ct_L45> into tensor<6x!ct_L45> + %inserted_slice_4 = tensor.insert_slice %arg4 into %inserted_slice_3[3] [1] [1] : tensor<1x!ct_L45> into tensor<6x!ct_L45> + %inserted_slice_5 = tensor.insert_slice %arg5 into %inserted_slice_4[4] [1] [1] : tensor<1x!ct_L45> into tensor<6x!ct_L45> + %ct_6 = lwe.rmul_plain %ct, %pt : (!ct_L45, !pt) -> !ct_L45 + %1 = arith.mulf %arg6, %cst fastmath : tensor<1x1024xf32> + %extracted_slice = tensor.extract_slice %1[0, 0] [1, 1024] [1, 1] : tensor<1x1024xf32> to tensor<1024xf32> + %pt_7 = lwe.rlwe_encode %extracted_slice {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %ct_8 = lwe.radd_plain %ct_6, %pt_7 : (!ct_L45, !pt) -> !ct_L45 + %inserted = tensor.insert %ct_8 into %arg7[%c0] : tensor<1x!ct_L45> + %inserted_slice_9 = tensor.insert_slice %inserted into %inserted_slice_5[5] [1] [1] : tensor<1x!ct_L45> into tensor<6x!ct_L45> + %2 = scf.for %arg8 = %c0_i32 to %c6_i32 step %c1_i32 iter_args(%arg9 = %cst_0) -> (tensor<6x1024xf32>) : i32 { + %4 = scf.for %arg10 = %c0_i32 to %c784_i32 step %c1_i32 iter_args(%arg11 = %arg9) -> (tensor<6x1024xf32>) : i32 { + %5 = arith.index_cast %arg8 : i32 to index + %6 = arith.index_cast %arg10 : i32 to index + %inserted_23 = tensor.insert %cst_1 into %arg11[%5, %6] : tensor<6x1024xf32> + scf.yield %inserted_23 : tensor<6x1024xf32> + } + scf.yield %4 : tensor<6x1024xf32> + } + %extracted_slice_10 = tensor.extract_slice %2[0, 0] [1, 1024] [1, 1] : tensor<6x1024xf32> to tensor<1024xf32> + %pt_11 = lwe.rlwe_encode %extracted_slice_10 {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %extracted_slice_12 = tensor.extract_slice %2[1, 0] [1, 1024] [1, 1] : tensor<6x1024xf32> to tensor<1024xf32> + %pt_13 = lwe.rlwe_encode %extracted_slice_12 {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %extracted_slice_14 = tensor.extract_slice %2[2, 0] [1, 1024] [1, 1] : tensor<6x1024xf32> to tensor<1024xf32> + %pt_15 = lwe.rlwe_encode %extracted_slice_14 {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %extracted_slice_16 = tensor.extract_slice %2[3, 0] [1, 1024] [1, 1] : tensor<6x1024xf32> to tensor<1024xf32> + %pt_17 = lwe.rlwe_encode %extracted_slice_16 {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %extracted_slice_18 = tensor.extract_slice %2[4, 0] [1, 1024] [1, 1] : tensor<6x1024xf32> to tensor<1024xf32> + %pt_19 = lwe.rlwe_encode %extracted_slice_18 {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %extracted_slice_20 = tensor.extract_slice %2[5, 0] [1, 1024] [1, 1] : tensor<6x1024xf32> to tensor<1024xf32> + %pt_21 = lwe.rlwe_encode %extracted_slice_20 {encoding = #inverse_canonical_encoding, ring = #ring_f64_1_x1024} : tensor<1024xf32> -> !pt + %from_elements = tensor.from_elements %pt_11, %pt_13, %pt_15, %pt_17, %pt_19, %pt_21 : tensor<6x!pt> + %3 = affine.for %arg8 = 0 to 6 iter_args(%arg9 = %0) -> (tensor<6x!ct_L45>) { + %extracted_23 = tensor.extract %inserted_slice_9[%arg8] : tensor<6x!ct_L45> + %extracted_24 = tensor.extract %from_elements[%arg8] : tensor<6x!pt> + %ct_25 = lwe.rmul_plain %extracted_23, %extracted_24 : (!ct_L45, !pt) -> !ct_L45 + %inserted_26 = tensor.insert %ct_25 into %arg9[%arg8] : tensor<6x!ct_L45> + affine.yield %inserted_26 : tensor<6x!ct_L45> + } + %extracted = tensor.extract %3[%c0] : tensor<6x!ct_L45> + %from_elements_22 = tensor.from_elements %extracted : tensor<1x!ct_L45> + return %from_elements_22 : tensor<1x!ct_L45> + } +} diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_add.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_add.mlir index 10df2b01c7..ef80b847a1 100644 --- a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_add.mlir +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_add.mlir @@ -9,7 +9,7 @@ !pt = !lattigo.rlwe.plaintext module attributes {bgv.schemeParam = #bgv.scheme_param, scheme.bgv} { // CHECK: func.func @add - func.func @add(%evaluator: !evaluator, %param: !param, %encoder: !encoder, %ct: !ct) -> !ct attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { + func.func @add(%evaluator: !evaluator, %param: !param, %encoder: !encoder, %ct: !ct) -> !ct { // CHECK-COUNT-3: lattigo.bgv.add // CHECK-NOT: lattigo.bgv.add_new %ct_0 = lattigo.bgv.add_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct diff --git a/tests/Dialect/Openfhe/IR/ops.mlir b/tests/Dialect/Openfhe/IR/ops.mlir index c7a9229a4a..48ca46fe51 100644 --- a/tests/Dialect/Openfhe/IR/ops.mlir +++ b/tests/Dialect/Openfhe/IR/ops.mlir @@ -2,35 +2,12 @@ // This simply tests for syntax. -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> - -!rns_L0_ = !rns.rns - -#ring_Z65537_i64_1_x1024_ = #polynomial.ring> -#ring_rns_L0_1_x1024_ = #polynomial.ring> - -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#inverse_canonical_encoding = #lwe.inverse_canonical_encoding -#key = #lwe.key<> - -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> - -#plaintext_space = #lwe.plaintext_space -#plaintext_space_f16 = #lwe.plaintext_space - -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -!ptf16 = !lwe.lwe_plaintext, plaintext_space = #plaintext_space_f16> - -#ciphertext_space_L0_ = #lwe.ciphertext_space -#ciphertext_space_L0_D3 = #lwe.ciphertext_space - !pk = !openfhe.public_key !sk = !openfhe.private_key !ek = !openfhe.eval_key !cc = !openfhe.crypto_context -!ct = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> -!ct_D3 = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_D3, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!pt = !openfhe.plaintext +!ct = !openfhe.ciphertext module { // CHECK: func @test_make_packed_plaintext @@ -40,9 +17,9 @@ module { } // CHECK: func @test_make_ckks_packed_plaintext - func.func @test_make_ckks_packed_plaintext(%cc: !cc, %arg0 : tensor<32xf16>) -> !ptf16 { - %pt = openfhe.make_ckks_packed_plaintext %cc, %arg0 : (!cc, tensor<32xf16>) -> !ptf16 - return %pt : !ptf16 + func.func @test_make_ckks_packed_plaintext(%cc: !cc, %arg0 : tensor<32xf16>) -> !pt { + %pt = openfhe.make_ckks_packed_plaintext %cc, %arg0 : (!cc, tensor<32xf16>) -> !pt + return %pt : !pt } // CHECK: func @test_encrypt @@ -57,13 +34,6 @@ module { return } - // CHECK: func @test_encode - func.func @test_encode(%arg0: tensor<32xi3>, %pt : !pt, %pk: !pk) { - %0 = arith.extsi %arg0 : tensor<32xi3> to tensor<32xi64> - %out = lwe.rlwe_encode %0 {encoding=#full_crt_packing_encoding, ring=#ring_Z65537_i64_1_x1024_} : tensor<32xi64> -> !pt - return - } - // CHECK: func @test_negate func.func @test_negate(%cc : !cc, %pt : !pt, %pk: !pk) { %ct = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct @@ -78,20 +48,6 @@ module { %out = openfhe.add %cc, %c1, %c2: (!cc, !ct, !ct) -> !ct return } - // CHECK: func @test_inplace_add - func.func @test_inplace_add(%cc: !cc, %pt : !pt, %pk : !pk) { - %c1 = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct - %c2 = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct - openfhe.add_inplace %cc, %c1, %c2: (!cc, !ct, !ct) -> () - return - } - // CHECK: func @test_inplace_sub - func.func @test_inplace_sub(%cc: !cc, %pt : !pt, %pk : !pk) { - %c1 = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct - %c2 = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct - openfhe.sub_inplace %cc, %c1, %c2: (!cc, !ct, !ct) -> () - return - } // CHECK: func @test_sub func.func @test_sub(%cc : !cc, %pt : !pt, %pk: !pk) { @@ -127,7 +83,7 @@ module { func.func @test_mul_no_relin(%cc : !cc, %pt : !pt, %pk: !pk) { %c1 = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct %c2 = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct - %out = openfhe.mul_no_relin %cc, %c1, %c2: (!cc, !ct, !ct) -> !ct_D3 + %out = openfhe.mul_no_relin %cc, %c1, %c2: (!cc, !ct, !ct) -> !ct return } diff --git a/tests/Dialect/Openfhe/Transforms/configure_crypto_context.mlir b/tests/Dialect/Openfhe/Transforms/configure_crypto_context.mlir index 496e526db5..96a2e6f1cf 100644 --- a/tests/Dialect/Openfhe/Transforms/configure_crypto_context.mlir +++ b/tests/Dialect/Openfhe/Transforms/configure_crypto_context.mlir @@ -1,44 +1,25 @@ // RUN: heir-opt --openfhe-configure-crypto-context=entry-function=simple_sum %s | FileCheck %s -!Z1032955396097_i64_ = !mod_arith.int<1032955396097 : i64> -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -#modulus_chain_L5_C1_ = #lwe.modulus_chain, current = 1> -!rns_L0_ = !rns.rns -!rns_L1_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#plaintext_space = #lwe.plaintext_space -#ring_rns_L0_1_x32_ = #polynomial.ring> -#ring_rns_L1_1_x32_ = #polynomial.ring> -!pt = !lwe.lwe_plaintext>, plaintext_space = #plaintext_space> -!pt1 = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -#ciphertext_space_L0_ = #lwe.ciphertext_space -#ciphertext_space_L1_ = #lwe.ciphertext_space -!ct_L0_ = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> -!ct_L1_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L1_, key = #key, modulus_chain = #modulus_chain_L5_C1_> -!ct_L1_1 = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L1_, key = #key, modulus_chain = #modulus_chain_L5_C1_> +!ct = !openfhe.ciphertext +!pt = !openfhe.plaintext -func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct_L1_) -> !ct_L0_ { +func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct) -> !ct { %cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi64> - %0 = openfhe.rot %arg0, %arg1 {index = 16 : index} : (!openfhe.crypto_context, !ct_L1_) -> !ct_L1_ - %1 = openfhe.add %arg0, %arg1, %0 : (!openfhe.crypto_context, !ct_L1_, !ct_L1_) -> !ct_L1_ - %2 = openfhe.rot %arg0, %1 {index = 8 : index} : (!openfhe.crypto_context, !ct_L1_) -> !ct_L1_ - %3 = openfhe.add %arg0, %1, %2 : (!openfhe.crypto_context, !ct_L1_, !ct_L1_) -> !ct_L1_ - %4 = openfhe.rot %arg0, %3 {index = 4 : index} : (!openfhe.crypto_context, !ct_L1_) -> !ct_L1_ - %5 = openfhe.add %arg0, %3, %4 : (!openfhe.crypto_context, !ct_L1_, !ct_L1_) -> !ct_L1_ - %6 = openfhe.rot %arg0, %5 {index = 2 : index} : (!openfhe.crypto_context, !ct_L1_) -> !ct_L1_ - %7 = openfhe.add %arg0, %5, %6 : (!openfhe.crypto_context, !ct_L1_, !ct_L1_) -> !ct_L1_ - %8 = openfhe.rot %arg0, %7 {index = 1 : index} : (!openfhe.crypto_context, !ct_L1_) -> !ct_L1_ - %9 = openfhe.add %arg0, %7, %8 : (!openfhe.crypto_context, !ct_L1_, !ct_L1_) -> !ct_L1_ + %0 = openfhe.rot %arg0, %arg1 {index = 16 : index} : (!openfhe.crypto_context, !ct) -> !ct + %1 = openfhe.add %arg0, %arg1, %0 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %2 = openfhe.rot %arg0, %1 {index = 8 : index} : (!openfhe.crypto_context, !ct) -> !ct + %3 = openfhe.add %arg0, %1, %2 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %4 = openfhe.rot %arg0, %3 {index = 4 : index} : (!openfhe.crypto_context, !ct) -> !ct + %5 = openfhe.add %arg0, %3, %4 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %6 = openfhe.rot %arg0, %5 {index = 2 : index} : (!openfhe.crypto_context, !ct) -> !ct + %7 = openfhe.add %arg0, %5, %6 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %8 = openfhe.rot %arg0, %7 {index = 1 : index} : (!openfhe.crypto_context, !ct) -> !ct + %9 = openfhe.add %arg0, %7, %8 : (!openfhe.crypto_context, !ct, !ct) -> !ct %10 = openfhe.make_packed_plaintext %arg0, %cst : (!openfhe.crypto_context, tensor<32xi64>) -> !pt - %11 = openfhe.mul_plain %arg0, %9, %10 : (!openfhe.crypto_context, !ct_L1_, !pt) -> !ct_L1_ - %12 = openfhe.rot %arg0, %11 {index = 31 : index} : (!openfhe.crypto_context, !ct_L1_) -> !ct_L1_ - %13 = lwe.reinterpret_application_data %12 : !ct_L1_ to !ct_L1_1 - %14 = openfhe.mod_reduce %arg0, %13 : (!openfhe.crypto_context, !ct_L1_1) -> !ct_L0_ - return %14 : !ct_L0_ + %11 = openfhe.mul_plain %arg0, %9, %10 : (!openfhe.crypto_context, !ct, !pt) -> !ct + %12 = openfhe.rot %arg0, %11 {index = 31 : index} : (!openfhe.crypto_context, !ct) -> !ct + %14 = openfhe.mod_reduce %arg0, %12 : (!openfhe.crypto_context, !ct) -> !ct + return %14 : !ct } // CHECK: @simple_sum diff --git a/tests/Dialect/Openfhe/Transforms/configure_crypto_context_bootstrap.mlir b/tests/Dialect/Openfhe/Transforms/configure_crypto_context_bootstrap.mlir index d193dfb080..1162793962 100644 --- a/tests/Dialect/Openfhe/Transforms/configure_crypto_context_bootstrap.mlir +++ b/tests/Dialect/Openfhe/Transforms/configure_crypto_context_bootstrap.mlir @@ -1,21 +1,10 @@ // RUN: heir-opt --openfhe-configure-crypto-context=entry-function=bootstrap %s | FileCheck %s -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -!rns_L0_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#plaintext_space = #lwe.plaintext_space -#ring_rns_L0_1_x32_ = #polynomial.ring> -#ciphertext_space_L0_ = #lwe.ciphertext_space -!ct_L0_ = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!ct = !openfhe.ciphertext - -func.func @bootstrap(%arg0: !openfhe.crypto_context, %arg1: !ct_L0_) -> !ct_L0_ { - %0 = openfhe.bootstrap %arg0, %arg1 : (!openfhe.crypto_context, !ct_L0_) -> !ct_L0_ - return %0 : !ct_L0_ +func.func @bootstrap(%arg0: !openfhe.crypto_context, %arg1: !ct) -> !ct { + %0 = openfhe.bootstrap %arg0, %arg1 : (!openfhe.crypto_context, !ct) -> !ct + return %0 : !ct } // CHECK: @bootstrap diff --git a/tests/Dialect/Openfhe/Transforms/configure_crypto_context_detect.mlir b/tests/Dialect/Openfhe/Transforms/configure_crypto_context_detect.mlir index ab2b21bf98..75f20d36ef 100644 --- a/tests/Dialect/Openfhe/Transforms/configure_crypto_context_detect.mlir +++ b/tests/Dialect/Openfhe/Transforms/configure_crypto_context_detect.mlir @@ -2,27 +2,11 @@ // Test whether detection works for one main function -!Z1032955396097_i64_ = !mod_arith.int<1032955396097 : i64> -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -#modulus_chain_L5_C1_ = #lwe.modulus_chain, current = 1> -!rns_L0_ = !rns.rns -!rns_L1_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#plaintext_space = #lwe.plaintext_space -#ring_rns_L0_1_x32_ = #polynomial.ring> -#ring_rns_L1_1_x32_ = #polynomial.ring> -#ciphertext_space_L0_ = #lwe.ciphertext_space -#ciphertext_space_L1_ = #lwe.ciphertext_space -!ct_L0_ = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> -!ct_L1_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L1_, key = #key, modulus_chain = #modulus_chain_L5_C1_> +!ct = !openfhe.ciphertext -func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct_L1_) -> !ct_L0_ { - %14 = openfhe.mod_reduce %arg0, %arg1 : (!openfhe.crypto_context, !ct_L1_) -> !ct_L0_ - return %14 : !ct_L0_ +func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct) -> !ct { + %14 = openfhe.mod_reduce %arg0, %arg1 : (!openfhe.crypto_context, !ct) -> !ct + return %14 : !ct } // CHECK: @simple_sum @@ -33,32 +17,16 @@ func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct_L1_) -> !ct_L0_ // Test whether called function is skipped -!Z1032955396097_i64_ = !mod_arith.int<1032955396097 : i64> -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -#modulus_chain_L5_C1_ = #lwe.modulus_chain, current = 1> -!rns_L0_ = !rns.rns -!rns_L1_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#plaintext_space = #lwe.plaintext_space -#ring_rns_L0_1_x32_ = #polynomial.ring> -#ring_rns_L1_1_x32_ = #polynomial.ring> -#ciphertext_space_L0_ = #lwe.ciphertext_space -#ciphertext_space_L1_ = #lwe.ciphertext_space -!ct_L0_ = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> -!ct_L1_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L1_, key = #key, modulus_chain = #modulus_chain_L5_C1_> +!ct = !openfhe.ciphertext -func.func @test(%arg0: !openfhe.crypto_context, %arg1: !ct_L0_) -> !ct_L0_ { - return %arg1 : !ct_L0_ +func.func @test(%arg0: !openfhe.crypto_context, %arg1: !ct) -> !ct { + return %arg1 : !ct } -func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct_L1_) -> !ct_L0_ { - %0 = openfhe.mod_reduce %arg0, %arg1 : (!openfhe.crypto_context, !ct_L1_) -> !ct_L0_ - %1 = call @test(%arg0, %0) : (!openfhe.crypto_context, !ct_L0_) -> !ct_L0_ - return %1 : !ct_L0_ +func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct) -> !ct { + %0 = openfhe.mod_reduce %arg0, %arg1 : (!openfhe.crypto_context, !ct) -> !ct + %1 = call @test(%arg0, %0) : (!openfhe.crypto_context, !ct) -> !ct + return %1 : !ct } // CHECK: @test diff --git a/tests/Dialect/Openfhe/Transforms/eval_add_count.mlir b/tests/Dialect/Openfhe/Transforms/eval_add_count.mlir index bde6e1c502..a99feee1b6 100644 --- a/tests/Dialect/Openfhe/Transforms/eval_add_count.mlir +++ b/tests/Dialect/Openfhe/Transforms/eval_add_count.mlir @@ -1,6 +1,8 @@ // RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --openfhe-count-add-and-key-switch %s | FileCheck %s -// CHECK: mgmt.openfhe_params = #mgmt.openfhe_params +// CHECK: #mgmt.openfhe_params< +// CHECK-SAME: evalAddCount = 8 +// CHECK-SAME: keySwitchCount = 15 func.func @dot_product(%arg0: tensor<8xi16> {secret.secret}, %arg1: tensor<8xi16> {secret.secret}) -> i16 { %c0 = arith.constant 0 : index %c0_si16 = arith.constant 0 : i16 diff --git a/tests/Dialect/Openfhe/Transforms/precompute_bulk_rotations.mlir b/tests/Dialect/Openfhe/Transforms/precompute_bulk_rotations.mlir index d01a005b89..b4bd5781cc 100644 --- a/tests/Dialect/Openfhe/Transforms/precompute_bulk_rotations.mlir +++ b/tests/Dialect/Openfhe/Transforms/precompute_bulk_rotations.mlir @@ -1,36 +1,22 @@ // RUN: heir-opt --openfhe-fast-rotation-precompute %s | FileCheck %s -!Z1032955396097_i64 = !mod_arith.int<1032955396097 : i64> -!Z1095233372161_i64 = !mod_arith.int<1095233372161 : i64> -!Z65537_i64 = !mod_arith.int<65537 : i64> !cc = !openfhe.crypto_context -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L5_C0 = #lwe.modulus_chain, current = 0> -#modulus_chain_L5_C1 = #lwe.modulus_chain, current = 1> -!rns_L0 = !rns.rns -!rns_L1 = !rns.rns -#ring_Z65537_i64_1_x32 = #polynomial.ring> -#ring_rns_L0_1_x32 = #polynomial.ring> -#ring_rns_L1_1_x32 = #polynomial.ring> -#ciphertext_space_L0 = #lwe.ciphertext_space -#ciphertext_space_L1 = #lwe.ciphertext_space -!ct_L1 = !lwe.lwe_ciphertext>, plaintext_space = , ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L5_C1> +!ct = !openfhe.ciphertext module { - func.func @simple_sum(%cc: !cc, %ct: !ct_L1) -> !ct_L1 { + func.func @simple_sum(%cc: !cc, %ct: !ct) -> !ct { // CHECK: openfhe.fast_rotation_precompute // CHECK-COUNT-4: openfhe.fast_rotation // CHECK-NOT: openfhe.rot %cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi64> - %ct_0 = openfhe.rot %cc, %ct {index = 16 : index} : (!cc, !ct_L1) -> !ct_L1 - %ct_1 = openfhe.add %cc, %ct, %ct_0 : (!cc, !ct_L1, !ct_L1) -> !ct_L1 - %ct_2 = openfhe.rot %cc, %ct {index = 8 : index} : (!cc, !ct_L1) -> !ct_L1 - %ct_3 = openfhe.add %cc, %ct_1, %ct_2 : (!cc, !ct_L1, !ct_L1) -> !ct_L1 - %ct_4 = openfhe.rot %cc, %ct {index = 5 : index} : (!cc, !ct_L1) -> !ct_L1 - %ct_5 = openfhe.add %cc, %ct_3, %ct_4 : (!cc, !ct_L1, !ct_L1) -> !ct_L1 - %ct_6 = openfhe.rot %cc, %ct {index = 12 : index} : (!cc, !ct_L1) -> !ct_L1 - %ct_7 = openfhe.add %cc, %ct_5, %ct_6 : (!cc, !ct_L1, !ct_L1) -> !ct_L1 - return %ct_7 : !ct_L1 + %ct_0 = openfhe.rot %cc, %ct {index = 16 : index} : (!cc, !ct) -> !ct + %ct_1 = openfhe.add %cc, %ct, %ct_0 : (!cc, !ct, !ct) -> !ct + %ct_2 = openfhe.rot %cc, %ct {index = 8 : index} : (!cc, !ct) -> !ct + %ct_3 = openfhe.add %cc, %ct_1, %ct_2 : (!cc, !ct, !ct) -> !ct + %ct_4 = openfhe.rot %cc, %ct {index = 5 : index} : (!cc, !ct) -> !ct + %ct_5 = openfhe.add %cc, %ct_3, %ct_4 : (!cc, !ct, !ct) -> !ct + %ct_6 = openfhe.rot %cc, %ct {index = 12 : index} : (!cc, !ct) -> !ct + %ct_7 = openfhe.add %cc, %ct_5, %ct_6 : (!cc, !ct, !ct) -> !ct + return %ct_7 : !ct } } diff --git a/tests/Emitter/Lattigo/emit_lattigo.mlir b/tests/Emitter/Lattigo/emit_lattigo.mlir index c880e7f955..70471e06d3 100644 --- a/tests/Emitter/Lattigo/emit_lattigo.mlir +++ b/tests/Emitter/Lattigo/emit_lattigo.mlir @@ -181,7 +181,7 @@ module attributes {scheme.bgv} { module attributes {scheme.bgv} { func.func private @__heir_debug_0(!lattigo.bgv.evaluator, !lattigo.bgv.parameter, !lattigo.bgv.encoder, !lattigo.rlwe.decryptor, !lattigo.rlwe.ciphertext) - func.func @dot_product(%evaluator: !lattigo.bgv.evaluator, %param: !lattigo.bgv.parameter, %encoder: !lattigo.bgv.encoder, %decryptor: !lattigo.rlwe.decryptor, %ct: !lattigo.rlwe.ciphertext, %ct_0: !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext attributes {mgmt.openfhe_params = #mgmt.openfhe_params} { + func.func @dot_product(%evaluator: !lattigo.bgv.evaluator, %param: !lattigo.bgv.parameter, %encoder: !lattigo.bgv.encoder, %decryptor: !lattigo.rlwe.decryptor, %ct: !lattigo.rlwe.ciphertext, %ct_0: !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext { call @__heir_debug_0(%evaluator, %param, %encoder, %decryptor, %ct) {bound = "50", random = 3, complex = {test = 1.2}, secret.secret} : (!lattigo.bgv.evaluator, !lattigo.bgv.parameter, !lattigo.bgv.encoder, !lattigo.rlwe.decryptor, !lattigo.rlwe.ciphertext) -> () return %ct : !lattigo.rlwe.ciphertext } diff --git a/tests/Emitter/Openfhe/emit_bool.mlir b/tests/Emitter/Openfhe/emit_bool.mlir index 4af7fb546b..4785cbc6f3 100644 --- a/tests/Emitter/Openfhe/emit_bool.mlir +++ b/tests/Emitter/Openfhe/emit_bool.mlir @@ -1,41 +1,20 @@ // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s -!Z65537_i64 = !mod_arith.int<65537 : i64> -!Z67239937_i64 = !mod_arith.int<67239937 : i64> -!Z8796093202433_i64 = !mod_arith.int<8796093202433 : i64> !cc = !openfhe.crypto_context -!params = !openfhe.cc_params +!pt = !openfhe.plaintext !pk = !openfhe.public_key -!sk = !openfhe.private_key -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L1_C0 = #lwe.modulus_chain, current = 0> -#modulus_chain_L1_C1 = #lwe.modulus_chain, current = 1> -!rns_L0 = !rns.rns -!rns_L1 = !rns.rns -#ring_Z65537_i64_1_x1024 = #polynomial.ring> -#ring_rns_L0_1_x1024 = #polynomial.ring> -#ring_rns_L1_1_x1024 = #polynomial.ring> -!pt = !lwe.lwe_plaintext, plaintext_space = > -!pt1 = !lwe.lwe_plaintext, plaintext_space = > -#ciphertext_space_L0 = #lwe.ciphertext_space -#ciphertext_space_L1 = #lwe.ciphertext_space -#ciphertext_space_L1_D3 = #lwe.ciphertext_space -!ct_L0 = !lwe.lwe_ciphertext, plaintext_space = , ciphertext_space = #ciphertext_space_L0, key = #key, modulus_chain = #modulus_chain_L1_C0> -!ct_L1 = !lwe.lwe_ciphertext, plaintext_space = , ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1> -!ct_L1_1 = !lwe.lwe_ciphertext, plaintext_space = , ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1> -!ct_L1_D3 = !lwe.lwe_ciphertext, plaintext_space = , ciphertext_space = #ciphertext_space_L1_D3, key = #key, modulus_chain = #modulus_chain_L1_C1> +!ct = !openfhe.ciphertext module attributes {scheme.bgv} { // CHECK: CiphertextT emit_bool // CHECK-SAME: CryptoContextT [[cc:.*]], bool [[v0:.*]], PublicKeyT [[pk:.*]]) { - func.func @emit_bool(%cc: !cc, %arg0: i1, %pk: !pk) -> !ct_L1_1 { + func.func @emit_bool(%cc: !cc, %arg0: i1, %pk: !pk) -> !ct { // CHECK: std::vector [[v1:.*]](1024, [[v0]]); // CHECK-NEXT: std::vector [[v2:.*]](std::begin([[v1]]), std::end([[v1]])) %splat = tensor.splat %arg0 : tensor<1024xi1> %0 = arith.extui %splat : tensor<1024xi1> to tensor<1024xi64> - %pt = openfhe.make_packed_plaintext %cc, %0 : (!cc, tensor<1024xi64>) -> !pt1 - %ct = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt1, !pk) -> !ct_L1_1 - return %ct : !ct_L1_1 + %pt = openfhe.make_packed_plaintext %cc, %0 : (!cc, tensor<1024xi64>) -> !pt + %ct = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct + return %ct : !ct } } diff --git a/tests/Emitter/Openfhe/emit_loops.mlir b/tests/Emitter/Openfhe/emit_loops.mlir index a4ae14d764..c608c16331 100644 --- a/tests/Emitter/Openfhe/emit_loops.mlir +++ b/tests/Emitter/Openfhe/emit_loops.mlir @@ -1,22 +1,7 @@ // RUN: heir-translate %s --emit-openfhe-pke --split-input-file | FileCheck %s -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> - -!rns_L0_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#ring_rns_L0_1_x32_ = #polynomial.ring> - -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> - -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> - -#plaintext_space = #lwe.plaintext_space - -#ciphertext_space_L0_ = #lwe.ciphertext_space - -!ct_L0_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!cc = !openfhe.crypto_context +!ct = !openfhe.ciphertext module attributes {scheme.ckks} { // CHECK: test_affine_for @@ -26,11 +11,11 @@ module attributes {scheme.ckks} { // CHECK: [[ct1]] = [[cc]]->EvalRotate([[ct1]], 1); // CHECK: } // CHECK: return [[ct1]]; - func.func @test_affine_for(%cc: !openfhe.crypto_context, %ct: !ct_L0_) -> !ct_L0_ { - %1 = affine.for %arg0 = 1 to 2 iter_args(%arg1 = %ct) -> (!ct_L0_) { - %ct_12 = openfhe.rot %cc, %arg1 {index = 1 : index} : (!openfhe.crypto_context, !ct_L0_) -> !ct_L0_ - affine.yield %ct_12 : !ct_L0_ + func.func @test_affine_for(%cc: !openfhe.crypto_context, %ct: !ct) -> !ct { + %1 = affine.for %arg0 = 1 to 2 iter_args(%arg1 = %ct) -> (!ct) { + %ct_12 = openfhe.rot %cc, %arg1 {index = 1 : index} : (!openfhe.crypto_context, !ct) -> !ct + affine.yield %ct_12 : !ct } - return %1 : !ct_L0_ + return %1 : !ct } } diff --git a/tests/Emitter/Openfhe/emit_openfhe_pke.mlir b/tests/Emitter/Openfhe/emit_openfhe_pke.mlir index 8c7138eb32..cc3cbad9e4 100644 --- a/tests/Emitter/Openfhe/emit_openfhe_pke.mlir +++ b/tests/Emitter/Openfhe/emit_openfhe_pke.mlir @@ -1,27 +1,9 @@ // RUN: heir-translate %s --emit-openfhe-pke --split-input-file | FileCheck %s -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> - -!rns_L0_ = !rns.rns - -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#ring_rns_L0_1_x32_ = #polynomial.ring> - -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> - -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> - -#plaintext_space = #lwe.plaintext_space - -#ciphertext_space_L0_ = #lwe.ciphertext_space - !cc = !openfhe.crypto_context !ek = !openfhe.eval_key - -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -!ct = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!pt = !openfhe.plaintext +!ct = !openfhe.ciphertext // CHECK: CiphertextT test_basic_emitter( // CHECK-SAME: CryptoContextT [[CC:[^,]*]], @@ -69,30 +51,10 @@ module attributes {scheme.bgv} { // ----- -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> - -!rns_L0_ = !rns.rns - -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#ring_rns_L0_1_x32_ = #polynomial.ring> - -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> - -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> - -#plaintext_space = #lwe.plaintext_space - -#ciphertext_space_L0_ = #lwe.ciphertext_space - !cc = !openfhe.crypto_context !ek = !openfhe.eval_key - -!tensor_pt_ty = !lwe.lwe_plaintext>, plaintext_space = #plaintext_space> -!scalar_pt_ty = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -!tensor_ct_ty = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> -!scalar_ct_ty = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!pt = !openfhe.plaintext +!ct = !openfhe.ciphertext // CHECK: simple_sum( // CHECK-COUNT-6: EvalRotate @@ -105,32 +67,31 @@ module attributes {scheme.bgv} { // CHECK: int16_t // CHECK-SAME: [0] module attributes {scheme.ckks} { - func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !tensor_ct_ty) -> !scalar_ct_ty { - %1 = openfhe.rot %arg0, %arg1 { index = 16 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %2 = openfhe.add %arg0, %arg1, %1 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %4 = openfhe.rot %arg0, %2 { index = 8 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %5 = openfhe.add %arg0, %2, %4 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %7 = openfhe.rot %arg0, %5 { index = 4 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %8 = openfhe.add %arg0, %5, %7 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %10 = openfhe.rot %arg0, %8 { index = 2 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %11 = openfhe.add %arg0, %8, %10 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %13 = openfhe.rot %arg0, %11 { index = 1 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %14 = openfhe.add %arg0, %11, %13 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty + func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct) -> !ct { + %1 = openfhe.rot %arg0, %arg1 { index = 16 } : (!openfhe.crypto_context, !ct) -> !ct + %2 = openfhe.add %arg0, %arg1, %1 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %4 = openfhe.rot %arg0, %2 { index = 8 } : (!openfhe.crypto_context, !ct) -> !ct + %5 = openfhe.add %arg0, %2, %4 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %7 = openfhe.rot %arg0, %5 { index = 4 } : (!openfhe.crypto_context, !ct) -> !ct + %8 = openfhe.add %arg0, %5, %7 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %10 = openfhe.rot %arg0, %8 { index = 2 } : (!openfhe.crypto_context, !ct) -> !ct + %11 = openfhe.add %arg0, %8, %10 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %13 = openfhe.rot %arg0, %11 { index = 1 } : (!openfhe.crypto_context, !ct) -> !ct + %14 = openfhe.add %arg0, %11, %13 : (!openfhe.crypto_context, !ct, !ct) -> !ct %cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi16> - %15 = openfhe.make_packed_plaintext %arg0, %cst : (!openfhe.crypto_context, tensor<32xi16>) -> !tensor_pt_ty - %16 = openfhe.mul_plain %arg0, %14, %15 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_pt_ty) -> !tensor_ct_ty - %18 = openfhe.rot %arg0, %16 { index = 31 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %19 = lwe.reinterpret_application_data %18 : !tensor_ct_ty to !scalar_ct_ty - return %19 : !scalar_ct_ty + %15 = openfhe.make_packed_plaintext %arg0, %cst : (!openfhe.crypto_context, tensor<32xi16>) -> !pt + %16 = openfhe.mul_plain %arg0, %14, %15 : (!openfhe.crypto_context, !ct, !pt) -> !ct + %18 = openfhe.rot %arg0, %16 { index = 31 } : (!openfhe.crypto_context, !ct) -> !ct + return %18 : !ct } - func.func @simple_sum__encrypt(%arg0: !openfhe.crypto_context, %arg1: tensor<32xi16>, %arg2: !openfhe.public_key) -> !tensor_ct_ty { - %0 = openfhe.make_packed_plaintext %arg0, %arg1 : (!openfhe.crypto_context, tensor<32xi16>) -> !tensor_pt_ty - %1 = openfhe.encrypt %arg0, %0, %arg2 : (!openfhe.crypto_context, !tensor_pt_ty, !openfhe.public_key) -> !tensor_ct_ty - return %1 : !tensor_ct_ty + func.func @simple_sum__encrypt(%arg0: !openfhe.crypto_context, %arg1: tensor<32xi16>, %arg2: !openfhe.public_key) -> !ct { + %0 = openfhe.make_packed_plaintext %arg0, %arg1 : (!openfhe.crypto_context, tensor<32xi16>) -> !pt + %1 = openfhe.encrypt %arg0, %0, %arg2 : (!openfhe.crypto_context, !pt, !openfhe.public_key) -> !ct + return %1 : !ct } - func.func @simple_sum__decrypt(%arg0: !openfhe.crypto_context, %arg1: !scalar_ct_ty, %arg2: !openfhe.private_key) -> i16 { - %0 = openfhe.decrypt %arg0, %arg1, %arg2 : (!openfhe.crypto_context, !scalar_ct_ty, !openfhe.private_key) -> !scalar_pt_ty - %1 = lwe.rlwe_decode %0 {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32_} : !scalar_pt_ty -> i16 + func.func @simple_sum__decrypt(%arg0: !openfhe.crypto_context, %arg1: !ct, %arg2: !openfhe.private_key) -> i16 { + %0 = openfhe.decrypt %arg0, %arg1, %arg2 : (!openfhe.crypto_context, !ct, !openfhe.private_key) -> !pt + %1 = openfhe.decode %0 : !pt -> i16 return %1 : i16 } // CHECK: CiphertextT test_sub_plain( @@ -140,9 +101,9 @@ module attributes {scheme.ckks} { // CHECK-NEXT: const auto& [[v0:.*]] = [[CC]]->EvalSub([[ARG2]], [[ARG1]]); // CHECK-NEXT: return [[v0]]; // CHECK-NEXT: } - func.func @test_sub_plain(%cc: !openfhe.crypto_context, %pt :!tensor_pt_ty, %ct : !tensor_ct_ty) -> !tensor_ct_ty { - %0 = openfhe.sub_plain %cc, %ct, %pt: (!openfhe.crypto_context, !tensor_ct_ty, !tensor_pt_ty) -> !tensor_ct_ty - return %0 : !tensor_ct_ty + func.func @test_sub_plain(%cc: !openfhe.crypto_context, %pt :!pt, %ct : !ct) -> !ct { + %0 = openfhe.sub_plain %cc, %ct, %pt: (!openfhe.crypto_context, !ct, !pt) -> !ct + return %0 : !ct } } @@ -182,18 +143,10 @@ module attributes {scheme.ckks} { // ----- -!Z2147565569_i64_ = !mod_arith.int<2147565569 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L0_C0_ = #lwe.modulus_chain, current = 0> -!rns_L0_ = !rns.rns -#ring_Z65537_i64_1_x8_ = #polynomial.ring> -#plaintext_space = #lwe.plaintext_space -#ring_rns_L0_1_x8_ = #polynomial.ring> -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -#ciphertext_space_L0_ = #lwe.ciphertext_space -!ct_L0_ = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L0_C0_> +!cc = !openfhe.crypto_context +!ek = !openfhe.eval_key +!pt = !openfhe.plaintext +!ct = !openfhe.ciphertext // CHECK: __heir_debug(CryptoContextT, PrivateKeyT, CiphertextT, const std::map&) // CHECK: ["bound"] = "50" @@ -204,40 +157,27 @@ module attributes {scheme.ckks} { // CHECK: ["asm.result_ssa_format"] module attributes {scheme.bgv} { - func.func private @__heir_debug_0(!openfhe.crypto_context, !openfhe.private_key, !ct_L0_) - func.func @add(%cc: !openfhe.crypto_context, %sk: !openfhe.private_key, %ct: !ct_L0_) -> !ct_L0_ { - call @__heir_debug_0(%cc, %sk, %ct) {bound = "50", random = 3, complex = {test = 1.2}, secret.secret} : (!openfhe.crypto_context, !openfhe.private_key, !ct_L0_) -> () - return %ct : !ct_L0_ + func.func private @__heir_debug_0(!openfhe.crypto_context, !openfhe.private_key, !ct) + func.func @add(%cc: !openfhe.crypto_context, %sk: !openfhe.private_key, %ct: !ct) -> !ct { + call @__heir_debug_0(%cc, %sk, %ct) {bound = "50", random = 3, complex = {test = 1.2}, secret.secret} : (!openfhe.crypto_context, !openfhe.private_key, !ct) -> () + return %ct : !ct } } // ----- -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> - -!rns_L0_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#ring_rns_L0_1_x32_ = #polynomial.ring> - -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> - -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> - -#plaintext_space = #lwe.plaintext_space - -#ciphertext_space_L0_ = #lwe.ciphertext_space - -!ct_L0_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!cc = !openfhe.crypto_context +!ek = !openfhe.eval_key +!pt = !openfhe.plaintext +!ct = !openfhe.ciphertext module attributes {scheme.ckks} { // CHECK: test_func_call // CHECK: const auto& [[v0:.*]] = callee_secret - func.func private @callee_secret(!openfhe.crypto_context, !ct_L0_) -> !ct_L0_ - func.func @test_func_call(%cc: !openfhe.crypto_context, %arg0: !ct_L0_) -> !ct_L0_ { - %1 = call @callee_secret(%cc, %arg0) : (!openfhe.crypto_context, !ct_L0_) -> !ct_L0_ - return %1 : !ct_L0_ + func.func private @callee_secret(!openfhe.crypto_context, !ct) -> !ct + func.func @test_func_call(%cc: !openfhe.crypto_context, %arg0: !ct) -> !ct { + %1 = call @callee_secret(%cc, %arg0) : (!openfhe.crypto_context, !ct) -> !ct + return %1 : !ct } } @@ -412,18 +352,11 @@ module attributes {scheme.ckks} { // ----- -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -!rns_L0_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#ring_rns_L0_1_x32_ = #polynomial.ring> -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -#plaintext_space = #lwe.plaintext_space -#ciphertext_space_L0_ = #lwe.ciphertext_space !cc = !openfhe.crypto_context -!ct = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!ek = !openfhe.eval_key +!pt = !openfhe.plaintext +!ct = !openfhe.ciphertext + // CHECK: CiphertextT test_fast_rot( // CHECK-SAME: CryptoContextT [[CC:[^,]*]], diff --git a/tests/Emitter/Openfhe/emit_plaintext.mlir b/tests/Emitter/Openfhe/emit_plaintext.mlir index 6fc6ce503a..2dfdaeea3c 100644 --- a/tests/Emitter/Openfhe/emit_plaintext.mlir +++ b/tests/Emitter/Openfhe/emit_plaintext.mlir @@ -1,69 +1,51 @@ // Regression test for https://github.com/google/heir/issues/1621 // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s -!Z1073750017_i64 = !mod_arith.int<1073750017 : i64> -!Z65537_i64 = !mod_arith.int<65537 : i64> -!Z67239937_i64 = !mod_arith.int<67239937 : i64> -!cc = !openfhe.crypto_context !params = !openfhe.cc_params +!cc = !openfhe.crypto_context +!pt = !openfhe.plaintext !pk = !openfhe.public_key !sk = !openfhe.private_key -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> -#modulus_chain_L1_C0 = #lwe.modulus_chain, current = 0> -#modulus_chain_L1_C1 = #lwe.modulus_chain, current = 1> -!rns_L0 = !rns.rns -!rns_L1 = !rns.rns -#ring_Z65537_i64_1_x1024 = #polynomial.ring> -#ring_rns_L0_1_x1024 = #polynomial.ring> -#ring_rns_L1_1_x1024 = #polynomial.ring> -!pt = !lwe.lwe_plaintext, plaintext_space = > -!pt1 = !lwe.lwe_plaintext, plaintext_space = > -#ciphertext_space_L0 = #lwe.ciphertext_space -#ciphertext_space_L1 = #lwe.ciphertext_space -!ct_L0 = !lwe.lwe_ciphertext, plaintext_space = , ciphertext_space = #ciphertext_space_L0, key = #key, modulus_chain = #modulus_chain_L1_C0> -!ct_L1 = !lwe.lwe_ciphertext, plaintext_space = , ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1> -!ct_L1_1 = !lwe.lwe_ciphertext, plaintext_space = , ciphertext_space = #ciphertext_space_L1, key = #key, modulus_chain = #modulus_chain_L1_C1> +!ct = !openfhe.ciphertext + module attributes {scheme.bgv} { // CHECK: CiphertextT cond // CHECK-SAME: CryptoContextT [[cc:.*]], int64_t [[v0:.*]], int64_t [[v1:.*]], CiphertextT [[ct:.*]] - func.func @cond(%cc: !cc, %arg0: i64, %arg1: i64, %ct: !ct_L1) -> !ct_L0 { + func.func @cond(%cc: !cc, %arg0: i64, %arg1: i64, %ct: !ct) -> !ct { // CHECK: std::vector [[v2:.*]](1024, 1); // CHECK-NEXT: auto [[pt:.*]]_filled_n = [[cc]]->GetCryptoParameters()->GetElementParams()->GetRingDimension() / 2; // CHECK-NEXT: auto [[pt]]_filled = [[v2]] // CHECK: auto [[pt]] = [[cc]]->MakePackedPlaintext %cst = arith.constant dense<1> : tensor<1024xi64> %pt = openfhe.make_packed_plaintext %cc, %cst : (!cc, tensor<1024xi64>) -> !pt - %ct_0 = openfhe.negate %cc, %ct : (!cc, !ct_L1) -> !ct_L1 - %ct_1 = openfhe.add_plain %cc, %ct_0, %pt : (!cc, !ct_L1, !pt) -> !ct_L1 - %ct_2 = lwe.reinterpret_application_data %ct : !ct_L1 to !ct_L1_1 + %ct_0 = openfhe.negate %cc, %ct : (!cc, !ct) -> !ct + %ct_1 = openfhe.add_plain %cc, %ct_0, %pt : (!cc, !ct, !pt) -> !ct %splat = tensor.splat %arg0 : tensor<1024xi64> // CHECK: [[cc]]->MakePackedPlaintext - %pt_3 = openfhe.make_packed_plaintext %cc, %splat : (!cc, tensor<1024xi64>) -> !pt1 - %ct_4 = openfhe.mul_plain %cc, %ct_2, %pt_3 : (!cc, !ct_L1_1, !pt1) -> !ct_L1_1 - %ct_5 = lwe.reinterpret_application_data %ct_1 : !ct_L1 to !ct_L1_1 + %pt_3 = openfhe.make_packed_plaintext %cc, %splat : (!cc, tensor<1024xi64>) -> !pt + %ct_4 = openfhe.mul_plain %cc, %ct_1, %pt_3 : (!cc, !ct, !pt) -> !ct %splat_6 = tensor.splat %arg1 : tensor<1024xi64> // CHECK: [[cc]]->MakePackedPlaintext - %pt_7 = openfhe.make_packed_plaintext %cc, %splat_6 : (!cc, tensor<1024xi64>) -> !pt1 - %ct_8 = openfhe.mul_plain %cc, %ct_5, %pt_7 : (!cc, !ct_L1_1, !pt1) -> !ct_L1_1 - %ct_9 = openfhe.add %cc, %ct_4, %ct_8 : (!cc, !ct_L1_1, !ct_L1_1) -> !ct_L1_1 + %pt_7 = openfhe.make_packed_plaintext %cc, %splat_6 : (!cc, tensor<1024xi64>) -> !pt + %ct_8 = openfhe.mul_plain %cc, %ct_4, %pt_7 : (!cc, !ct, !pt) -> !ct + %ct_9 = openfhe.add %cc, %ct_4, %ct_8 : (!cc, !ct, !ct) -> !ct // CHECK-NOT: [[pt]]_filled_n - %pt_10 = openfhe.make_packed_plaintext %cc, %cst : (!cc, tensor<1024xi64>) -> !pt1 - %ct_11 = openfhe.add_plain %cc, %ct_9, %pt_10 : (!cc, !ct_L1_1, !pt1) -> !ct_L1_1 - %ct_12 = openfhe.mod_reduce %cc, %ct_11 : (!cc, !ct_L1_1) -> !ct_L0 + %pt_10 = openfhe.make_packed_plaintext %cc, %cst : (!cc, tensor<1024xi64>) -> !pt + %ct_11 = openfhe.add_plain %cc, %ct_9, %pt_10 : (!cc, !ct, !pt) -> !ct + %ct_12 = openfhe.mod_reduce %cc, %ct_11 : (!cc, !ct) -> !ct // CHECK: return - return %ct_12 : !ct_L0 + return %ct_12 : !ct } - func.func @cond__encrypt__arg2(%cc: !cc, %arg0: i1, %pk: !pk) -> !ct_L1 { + func.func @cond__encrypt__arg2(%cc: !cc, %arg0: i1, %pk: !pk) -> !ct { %splat = tensor.splat %arg0 : tensor<1024xi1> %0 = arith.extui %splat : tensor<1024xi1> to tensor<1024xi64> %pt = openfhe.make_packed_plaintext %cc, %0 : (!cc, tensor<1024xi64>) -> !pt - %ct = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct_L1 - return %ct : !ct_L1 + %ct = openfhe.encrypt %cc, %pt, %pk : (!cc, !pt, !pk) -> !ct + return %ct : !ct } - func.func @cond__decrypt__result0(%cc: !cc, %ct: !ct_L0, %sk: !sk) -> i64 { - %pt = openfhe.decrypt %cc, %ct, %sk : (!cc, !ct_L0, !sk) -> !pt1 - %0 = lwe.rlwe_decode %pt {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x1024} : !pt1 -> i64 + func.func @cond__decrypt__result0(%cc: !cc, %ct: !ct, %sk: !sk) -> i64 { + %pt = openfhe.decrypt %cc, %ct, %sk : (!cc, !ct, !sk) -> !pt + %0 = openfhe.decode %pt : !pt -> i64 return %0 : i64 } func.func @cond__generate_crypto_context() -> !cc { diff --git a/tests/Emitter/Openfhe/emit_pybind.mlir b/tests/Emitter/Openfhe/emit_pybind.mlir index 421f68ce60..c79a6dcd64 100644 --- a/tests/Emitter/Openfhe/emit_pybind.mlir +++ b/tests/Emitter/Openfhe/emit_pybind.mlir @@ -33,56 +33,35 @@ // CHECK: m.def("simple_sum__decrypt", &simple_sum__decrypt, py::call_guard()); // CHECK: } -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> - -!rns_L0_ = !rns.rns - -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#ring_rns_L0_1_x32_ = #polynomial.ring> - -#full_crt_packing_encoding = #lwe.full_crt_packing_encoding -#key = #lwe.key<> - -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> - -#plaintext_space = #lwe.plaintext_space - -#ciphertext_space_L0_ = #lwe.ciphertext_space - !cc = !openfhe.crypto_context !ek = !openfhe.eval_key +!ct = !openfhe.ciphertext +!pt = !openfhe.plaintext -!tensor_pt_ty = !lwe.lwe_plaintext>, plaintext_space = #plaintext_space> -!scalar_pt_ty = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -!tensor_ct_ty = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> -!scalar_ct_ty = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> - -func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !tensor_ct_ty) -> !scalar_ct_ty { - %1 = openfhe.rot %arg0, %arg1 { index = 16 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %2 = openfhe.add %arg0, %arg1, %1 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %4 = openfhe.rot %arg0, %2 { index = 8 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %5 = openfhe.add %arg0, %2, %4 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %7 = openfhe.rot %arg0, %5 { index = 4 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %8 = openfhe.add %arg0, %5, %7 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %10 = openfhe.rot %arg0, %8 { index = 2 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %11 = openfhe.add %arg0, %8, %10 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty - %13 = openfhe.rot %arg0, %11 { index = 1 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %14 = openfhe.add %arg0, %11, %13 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_ct_ty) -> !tensor_ct_ty +func.func @simple_sum(%arg0: !openfhe.crypto_context, %arg1: !ct) -> !ct { + %1 = openfhe.rot %arg0, %arg1 { index = 16 } : (!openfhe.crypto_context, !ct) -> !ct + %2 = openfhe.add %arg0, %arg1, %1 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %4 = openfhe.rot %arg0, %2 { index = 8 } : (!openfhe.crypto_context, !ct) -> !ct + %5 = openfhe.add %arg0, %2, %4 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %7 = openfhe.rot %arg0, %5 { index = 4 } : (!openfhe.crypto_context, !ct) -> !ct + %8 = openfhe.add %arg0, %5, %7 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %10 = openfhe.rot %arg0, %8 { index = 2 } : (!openfhe.crypto_context, !ct) -> !ct + %11 = openfhe.add %arg0, %8, %10 : (!openfhe.crypto_context, !ct, !ct) -> !ct + %13 = openfhe.rot %arg0, %11 { index = 1 } : (!openfhe.crypto_context, !ct) -> !ct + %14 = openfhe.add %arg0, %11, %13 : (!openfhe.crypto_context, !ct, !ct) -> !ct %cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi16> - %15 = openfhe.make_packed_plaintext %arg0, %cst : (!openfhe.crypto_context, tensor<32xi16>) -> !tensor_pt_ty - %16 = openfhe.mul_plain %arg0, %14, %15 : (!openfhe.crypto_context, !tensor_ct_ty, !tensor_pt_ty) -> !tensor_ct_ty - %18 = openfhe.rot %arg0, %16 { index = 31 } : (!openfhe.crypto_context, !tensor_ct_ty) -> !tensor_ct_ty - %19 = lwe.reinterpret_application_data %18 : !tensor_ct_ty to !scalar_ct_ty - return %19 : !scalar_ct_ty + %15 = openfhe.make_packed_plaintext %arg0, %cst : (!openfhe.crypto_context, tensor<32xi16>) -> !pt + %16 = openfhe.mul_plain %arg0, %14, %15 : (!openfhe.crypto_context, !ct, !pt) -> !ct + %18 = openfhe.rot %arg0, %16 { index = 31 } : (!openfhe.crypto_context, !ct) -> !ct + return %18 : !ct } -func.func @simple_sum__encrypt(%arg0: !openfhe.crypto_context, %arg1: tensor<32xi16>, %arg2: !openfhe.public_key) -> !tensor_ct_ty { - %0 = openfhe.make_packed_plaintext %arg0, %arg1 : (!openfhe.crypto_context, tensor<32xi16>) -> !tensor_pt_ty - %1 = openfhe.encrypt %arg0, %0, %arg2 : (!openfhe.crypto_context, !tensor_pt_ty, !openfhe.public_key) -> !tensor_ct_ty - return %1 : !tensor_ct_ty +func.func @simple_sum__encrypt(%arg0: !openfhe.crypto_context, %arg1: tensor<32xi16>, %arg2: !openfhe.public_key) -> !ct { + %0 = openfhe.make_packed_plaintext %arg0, %arg1 : (!openfhe.crypto_context, tensor<32xi16>) -> !pt + %1 = openfhe.encrypt %arg0, %0, %arg2 : (!openfhe.crypto_context, !pt, !openfhe.public_key) -> !ct + return %1 : !ct } -func.func @simple_sum__decrypt(%arg0: !openfhe.crypto_context, %arg1: !scalar_ct_ty, %arg2: !openfhe.private_key) -> i16 { - %0 = openfhe.decrypt %arg0, %arg1, %arg2 : (!openfhe.crypto_context, !scalar_ct_ty, !openfhe.private_key) -> !scalar_pt_ty - %1 = lwe.rlwe_decode %0 {encoding = #full_crt_packing_encoding, ring = #ring_Z65537_i64_1_x32_} : !scalar_pt_ty -> i16 +func.func @simple_sum__decrypt(%arg0: !openfhe.crypto_context, %arg1: !ct, %arg2: !openfhe.private_key) -> i16 { + %0 = openfhe.decrypt %arg0, %arg1, %arg2 : (!openfhe.crypto_context, !ct, !openfhe.private_key) -> !pt + %1 = openfhe.decode %0 : !pt -> i16 return %1 : i16 } diff --git a/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD b/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD index 581f570aae..f1b494e18f 100644 --- a/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD +++ b/tests/Examples/openfhe/bfv/dot_product_8_debug/BUILD @@ -1,7 +1,7 @@ # See README.md for setup required to run these tests -load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") load("@rules_cc//cc:cc_library.bzl", "cc_library") +# load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") package(default_applicable_licenses = ["@heir//:license"]) @@ -15,15 +15,16 @@ cc_library( ], ) -openfhe_end_to_end_test( - name = "dot_product_8_debug_test", - generated_lib_header = "dot_product_8_debug_lib.h", - heir_opt_flags = [ - "--annotate-module=backend=openfhe scheme=bfv", - "--mlir-to-bfv=ciphertext-degree=8192 annotate-noise-bound=true", - "--scheme-to-openfhe=insert-debug-handler-calls=true", - ], - mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", - test_src = "dot_product_8_debug_test.cpp", - deps = [":debug_helper"], -) +# TODO(#2514): re-enable with fixed debug helpers +# openfhe_end_to_end_test( +# name = "dot_product_8_debug_test", +# generated_lib_header = "dot_product_8_debug_lib.h", +# heir_opt_flags = [ +# "--annotate-module=backend=openfhe scheme=bfv", +# "--mlir-to-bfv=ciphertext-degree=8192 annotate-noise-bound=true", +# "--scheme-to-openfhe=insert-debug-handler-calls=true", +# ], +# mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", +# test_src = "dot_product_8_debug_test.cpp", +# deps = [":debug_helper"], +# ) diff --git a/tests/Examples/openfhe/bfv/dot_product_8_debug/dot_product_8_debug_test.cpp b/tests/Examples/openfhe/bfv/dot_product_8_debug/dot_product_8_debug_test.cpp index 4829d23a85..239e839b5e 100644 --- a/tests/Examples/openfhe/bfv/dot_product_8_debug/dot_product_8_debug_test.cpp +++ b/tests/Examples/openfhe/bfv/dot_product_8_debug/dot_product_8_debug_test.cpp @@ -3,6 +3,7 @@ #include #include "gtest/gtest.h" // from @googletest +#include "tests/Examples/openfhe/bfv/dot_product_8_debug/debug_helper.h" // Generated headers (block clang-format from messing up order) #include "tests/Examples/openfhe/bfv/dot_product_8_debug/dot_product_8_debug_lib.h" diff --git a/tests/Examples/openfhe/bgv/binops/binops.mlir b/tests/Examples/openfhe/bgv/binops/binops.mlir index 0684046154..a5daaa8ef8 100644 --- a/tests/Examples/openfhe/bgv/binops/binops.mlir +++ b/tests/Examples/openfhe/bgv/binops/binops.mlir @@ -1,18 +1,8 @@ // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s !cc = !openfhe.crypto_context +!ct = !openfhe.ciphertext -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -!rns_L0_ = !rns.rns -#ring_rns_L0_1_x8_ = #polynomial.ring> -#ring_Z65537_i64_1_x8_ = #polynomial.ring> -#inverse_canonical_encoding = #lwe.inverse_canonical_encoding -#plaintext_space = #lwe.plaintext_space -#ciphertext_space_L0_ = #lwe.ciphertext_space -!ct = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> module attributes {scheme.ckks} { func.func @test_binops(%cc : !cc, %input1 : !ct, %input2 : !ct) -> !ct { %add_res = openfhe.add %cc, %input1, %input2 : (!cc, !ct, !ct) -> !ct diff --git a/tests/Examples/openfhe/bgv/ciphertext_plaintext_ops/ciphertext_plaintext_ops.mlir b/tests/Examples/openfhe/bgv/ciphertext_plaintext_ops/ciphertext_plaintext_ops.mlir index cda37e7d0f..db725e3d4e 100644 --- a/tests/Examples/openfhe/bgv/ciphertext_plaintext_ops/ciphertext_plaintext_ops.mlir +++ b/tests/Examples/openfhe/bgv/ciphertext_plaintext_ops/ciphertext_plaintext_ops.mlir @@ -1,19 +1,8 @@ // RUN: heir-translate %s --emit-openfhe-pke | FileCheck %s !cc = !openfhe.crypto_context - -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -!rns_L0_ = !rns.rns -#ring_rns_L0_1_x8_ = #polynomial.ring> -#ring_Z65537_i64_1_x8_ = #polynomial.ring> -#inverse_canonical_encoding = #lwe.inverse_canonical_encoding -#plaintext_space = #lwe.plaintext_space -#ciphertext_space_L0_ = #lwe.ciphertext_space -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -!ct = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!pt = !openfhe.plaintext +!ct = !openfhe.ciphertext // [(input1 + input2) - input3] * input4 module attributes {scheme.bgv} { diff --git a/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD b/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD index 88a47bce58..ccabc055d0 100644 --- a/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD +++ b/tests/Examples/openfhe/bgv/dot_product_8_debug/BUILD @@ -1,18 +1,19 @@ # See README.md for setup required to run these tests -load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") +# load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") package(default_applicable_licenses = ["@heir//:license"]) -openfhe_end_to_end_test( - name = "dot_product_8_debug_test", - generated_lib_header = "dot_product_8_debug_lib.h", - heir_opt_flags = [ - "--annotate-module=backend=openfhe scheme=bgv", - "--mlir-to-bgv=ciphertext-degree=8192 annotate-noise-bound=true", - "--scheme-to-openfhe=insert-debug-handler-calls=true", - ], - mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", - tags = ["notap"], - test_src = "dot_product_8_debug_test.cpp", -) +# TODO(#2514): re-enable with fixed debug helpers +# openfhe_end_to_end_test( +# name = "dot_product_8_debug_test", +# generated_lib_header = "dot_product_8_debug_lib.h", +# heir_opt_flags = [ +# "--annotate-module=backend=openfhe scheme=bgv", +# "--mlir-to-bgv=ciphertext-degree=8192 annotate-noise-bound=true", +# "--scheme-to-openfhe=insert-debug-handler-calls=true", +# ], +# mlir_src = "@heir//tests/Examples/common:dot_product_8.mlir", +# tags = ["notap"], +# test_src = "dot_product_8_debug_test.cpp", +# ) diff --git a/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD b/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD index 4c12557ab9..8210d7efda 100644 --- a/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD +++ b/tests/Examples/openfhe/ckks/dot_product_8f_debug/BUILD @@ -1,23 +1,24 @@ # See README.md for setup required to run these tests -load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") +# load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test") package(default_applicable_licenses = ["@heir//:license"]) -openfhe_end_to_end_test( - name = "dot_product_8f_debug_test", - data = [ - "@heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log", - ], - generated_lib_header = "dot_product_8f_debug_lib.h", - heir_opt_flags = [ - "--annotate-module=backend=openfhe scheme=ckks", - "--mlir-to-ckks=ciphertext-degree=8 \ - plaintext-execution-result-file-name=$(location @heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log)", - "--scheme-to-openfhe=insert-debug-handler-calls=true", - ], - heir_translate_flags = [], - mlir_src = "@heir//tests/Examples/common:dot_product_8f.mlir", - tags = ["notap"], - test_src = "dot_product_8f_debug_test.cpp", -) +# TODO(#2514): re-enable with fixed debug helpers +# openfhe_end_to_end_test( +# name = "dot_product_8f_debug_test", +# data = [ +# "@heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log", +# ], +# generated_lib_header = "dot_product_8f_debug_lib.h", +# heir_opt_flags = [ +# "--annotate-module=backend=openfhe scheme=ckks", +# "--mlir-to-ckks=ciphertext-degree=8 \ +# plaintext-execution-result-file-name=$(location @heir//tests/Examples/plaintext/dot_product_f_debug:dot_product_8f_debug.log)", +# "--scheme-to-openfhe=insert-debug-handler-calls=true", +# ], +# heir_translate_flags = [], +# mlir_src = "@heir//tests/Examples/common:dot_product_8f.mlir", +# tags = ["notap"], +# test_src = "dot_product_8f_debug_test.cpp", +# ) diff --git a/tests/Examples/openfhe/ckks/lenet/BUILD b/tests/Examples/openfhe/ckks/lenet/BUILD index 3c6c1f7d8a..93dada1cfd 100644 --- a/tests/Examples/openfhe/ckks/lenet/BUILD +++ b/tests/Examples/openfhe/ckks/lenet/BUILD @@ -26,6 +26,7 @@ cc_binary( srcs = ["lenet_main.cpp"], tags = [ "manual", + "nofastbuild", "notap", ], deps = [ @@ -47,6 +48,10 @@ heir_opt( "--torch-linalg-to-ckks=ciphertext-degree=1024", "--scheme-to-openfhe", ], + tags = [ + "nofastbuild", + "requires-mem:28g", + ], ) cc_library( diff --git a/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/BUILD b/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/BUILD index 5d33a277cc..bb8bf706df 100644 --- a/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/BUILD +++ b/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/BUILD @@ -8,7 +8,7 @@ openfhe_end_to_end_test( name = "simple_ckks_bootstrapping_test", generated_lib_header = "simple_ckks_bootstrapping_lib.h", heir_opt_flags = [ - "--openfhe-configure-crypto-context=insecure=true ring-dim=128", + "--openfhe-configure-crypto-context=insecure=true ring-dim=128 mul-depth=22", ], heir_translate_flags = [ "--openfhe-include-type=source-relative", diff --git a/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/simple_ckks_bootstrapping.mlir b/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/simple_ckks_bootstrapping.mlir index 8eddc97f8c..c6ec760781 100644 --- a/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/simple_ckks_bootstrapping.mlir +++ b/tests/Examples/openfhe/ckks/simple_ckks_bootstrapping/simple_ckks_bootstrapping.mlir @@ -1,35 +1,9 @@ -!Z1005037682689_i64_ = !mod_arith.int<1005037682689 : i64> -!Z1032955396097_i64_ = !mod_arith.int<1032955396097 : i64> -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#inverse_canonical_encoding = #lwe.inverse_canonical_encoding -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -#modulus_chain_L5_C1_ = #lwe.modulus_chain, current = 1> -#modulus_chain_L5_C2_ = #lwe.modulus_chain, current = 2> -!rns_L0_ = !rns.rns -!rns_L1_ = !rns.rns -!rns_L2_ = !rns.rns -#ring_Z65537_i64_1_x32_ = #polynomial.ring> -#plaintext_space = #lwe.plaintext_space -#ring_rns_L0_1_x32_ = #polynomial.ring> -#ring_rns_L1_1_x32_ = #polynomial.ring> -#ring_rns_L2_1_x32_ = #polynomial.ring> -#ciphertext_space_L0_ = #lwe.ciphertext_space -#ciphertext_space_L1_ = #lwe.ciphertext_space -#ciphertext_space_L2_ = #lwe.ciphertext_space - -!ct_L0_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> -!ct_L1_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L1_, key = #key, modulus_chain = #modulus_chain_L5_C1_> -!ct_L2_ = !lwe.lwe_ciphertext>, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L2_, key = #key, modulus_chain = #modulus_chain_L5_C2_> +!ct = !openfhe.ciphertext +!pt = !openfhe.plaintext module attributes {scheme.ckks} { - func.func @simple_ckks_bootstrapping(%cc: !openfhe.crypto_context, %ct: !ct_L2_) -> !ct_L2_ { - // in FLEXIBLEAUTOEXT mode, openfhe won't execute those mod_reduce - // added here just for type conversion - %0 = openfhe.mod_reduce %cc, %ct : (!openfhe.crypto_context, !ct_L2_) -> !ct_L1_ - %1 = openfhe.mod_reduce %cc, %0 : (!openfhe.crypto_context, !ct_L1_) -> !ct_L0_ - %2 = openfhe.bootstrap %cc, %1 : (!openfhe.crypto_context, !ct_L0_) -> !ct_L2_ - return %2 : !ct_L2_ + func.func @simple_ckks_bootstrapping(%cc: !openfhe.crypto_context, %ct: !ct) -> !ct { + %2 = openfhe.bootstrap %cc, %ct : (!openfhe.crypto_context, !ct) -> !ct + return %2 : !ct } } diff --git a/tests/Transforms/annotate_muldepth/BUILD b/tests/Transforms/annotate_muldepth/BUILD new file mode 100644 index 0000000000..c571e6fc6d --- /dev/null +++ b/tests/Transforms/annotate_muldepth/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Transforms/annotate_muldepth/doctest.mlir b/tests/Transforms/annotate_muldepth/doctest.mlir new file mode 100644 index 0000000000..47448b45ca --- /dev/null +++ b/tests/Transforms/annotate_muldepth/doctest.mlir @@ -0,0 +1,30 @@ +// RUN: heir-opt --annotate-muldepth %s | FileCheck %s + +// CHECK: func.func @doctest +// CHECK-SAME: !secret.secret {secret.mul_depth = 0 +// CHECK: secret.generic +// CHECK-SAME: !secret.secret {secret.mul_depth = 0 +// CHECK: ^body +// CHECK-NEXT: arith.muli +// CHECK-NEXT: arith.muli +// CHECK-SAME: secret.mul_depth = 1 +// CHECK-NEXT: arith.muli +// CHECK-SAME: secret.mul_depth = 1 +// CHECK-NEXT: arith.muli +// CHECK-SAME: secret.mul_depth = 2 +// CHECK-NEXT: arith.muli +// CHECK-SAME: secret.mul_depth = 3 +// CHECK-NEXT: secret.yield + +func.func @doctest(%x: !secret.secret, %y: i32) -> !secret.secret { + %0 = secret.generic(%x: !secret.secret, %y: i32) { + ^body(%secret_val: i32, %public_val: i32): + %1 = arith.muli %public_val, %public_val : i32 + %2 = arith.muli %secret_val, %secret_val : i32 + %3 = arith.muli %secret_val, %public_val : i32 + %4 = arith.muli %2, %3 : i32 + %5 = arith.muli %4, %3 : i32 + secret.yield %5 : i32 + } -> !secret.secret + return %0 : !secret.secret +} diff --git a/tests/Transforms/annotate_muldepth/openfhe.mlir b/tests/Transforms/annotate_muldepth/openfhe.mlir new file mode 100644 index 0000000000..dac0ffbdd7 --- /dev/null +++ b/tests/Transforms/annotate_muldepth/openfhe.mlir @@ -0,0 +1,27 @@ +// RUN: heir-opt --annotate-muldepth --mlir-print-local-scope %s | FileCheck %s + +// Tests other secret-like type by interface + +// CHECK: func.func @doctest +// CHECK-SAME: !openfhe.ciphertext {secret.mul_depth = 0 +// CHECK-NEXT: openfhe.mul +// CHECK-SAME: secret.mul_depth = 1 +// CHECK-NEXT: openfhe.mul_plain +// CHECK-SAME: secret.mul_depth = 1 +// CHECK-NEXT: openfhe.mul +// CHECK-SAME: secret.mul_depth = 2 +// CHECK-NEXT: openfhe.mul +// CHECK-SAME: secret.mul_depth = 3 +// CHECK-NEXT: return + +!ct = !openfhe.ciphertext +!pt = !openfhe.plaintext +!cc = !openfhe.crypto_context + +func.func @doctest(%cc: !cc, %secret_val: !ct, %public_val: !pt) -> !ct { + %2 = openfhe.mul %cc, %secret_val, %secret_val : (!cc, !ct, !ct) -> !ct + %3 = openfhe.mul_plain %cc, %secret_val, %public_val : (!cc, !ct, !pt) -> !ct + %4 = openfhe.mul %cc, %2, %3 : (!cc, !ct, !ct) -> !ct + %5 = openfhe.mul %cc, %4, %3 : (!cc, !ct, !ct) -> !ct + return %5 : !ct +} diff --git a/tests/Transforms/forward_insert_to_extract/forward_insert_to_extract.mlir b/tests/Transforms/forward_insert_to_extract/forward_insert_to_extract.mlir index 0e049970b3..bda1b89c86 100644 --- a/tests/Transforms/forward_insert_to_extract/forward_insert_to_extract.mlir +++ b/tests/Transforms/forward_insert_to_extract/forward_insert_to_extract.mlir @@ -2,19 +2,8 @@ !cc = !openfhe.crypto_context - -!Z1095233372161_i64_ = !mod_arith.int<1095233372161 : i64> -!Z65537_i64_ = !mod_arith.int<65537 : i64> -#key = #lwe.key<> -#modulus_chain_L5_C0_ = #lwe.modulus_chain, current = 0> -!rns_L0_ = !rns.rns -#ring_rns_L0_1_x16_ = #polynomial.ring> -#ring_Z65537_i64_1_x16_ = #polynomial.ring> -#inverse_canonical_encoding = #lwe.inverse_canonical_encoding -#plaintext_space = #lwe.plaintext_space -#ciphertext_space_L0_ = #lwe.ciphertext_space -!pt = !lwe.lwe_plaintext, plaintext_space = #plaintext_space> -!ct = !lwe.lwe_ciphertext, plaintext_space = #plaintext_space, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L5_C0_> +!ct = !openfhe.ciphertext +!pt = !openfhe.plaintext // CHECK: @successful_forwarding diff --git a/tools/BUILD b/tools/BUILD index 58b6b083d7..6b1aee1ee6 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -106,6 +106,7 @@ cc_binary( "@heir//lib/Transforms/ActivationCanonicalizations", "@heir//lib/Transforms/AddClientInterface", "@heir//lib/Transforms/AnnotateModule", + "@heir//lib/Transforms/AnnotateMulDepth", "@heir//lib/Transforms/AnnotateSecretness", "@heir//lib/Transforms/ApplyFolders", "@heir//lib/Transforms/CompareToSignRewrite", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 363ce7df26..909dfc6fdd 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -59,6 +59,7 @@ #include "lib/Transforms/ActivationCanonicalizations/ActivationCanonicalizations.h" #include "lib/Transforms/AddClientInterface/AddClientInterface.h" #include "lib/Transforms/AnnotateModule/AnnotateModule.h" +#include "lib/Transforms/AnnotateMulDepth/AnnotateMulDepth.h" #include "lib/Transforms/AnnotateSecretness/AnnotateSecretness.h" #include "lib/Transforms/ApplyFolders/ApplyFolders.h" #include "lib/Transforms/CompareToSignRewrite/CompareToSignRewrite.h" @@ -284,6 +285,7 @@ int main(int argc, char** argv) { registerDropUnitDims(); registerAnnotateModulePasses(); registerAnnotateSecretnessPasses(); + registerAnnotateMulDepthPasses(); registerApplyFoldersPasses(); registerFoldPlaintextMasksPasses(); registerForwardInsertSliceToExtractSlicePasses(); @@ -369,8 +371,9 @@ int main(int argc, char** argv) { // Interfaces in HEIR secret::registerBufferizableOpInterfaceExternalModels(registry); rns::registerExternalRNSTypeInterfaces(registry); - registerOperandAndResultAttrInterface(registry); + registerIncreasesMulDepthOpInterface(registry); registerLayoutConversionHoistableInterface(registry); + registerOperandAndResultAttrInterface(registry); registerOperandLayoutRequirementOpInterface(registry); PassPipelineRegistration<>(