From 02aa1cb6cb4990e1ec06ac0b86812c83c31b3f3e Mon Sep 17 00:00:00 2001 From: SaiHoCao Date: Mon, 24 Feb 2025 14:06:33 +0800 Subject: [PATCH] Dev csh (#1) * add:GemmFMA&&Test * add:QKMulFMASingle test bug * update but still error * add OutValue&Single * fix && AttnScores pass * Single FxpError &&SingleQueue Error * add fork test * fix OutValue and all passed but some TODO should fix * change verilator version * change gcc version --------- Co-authored-by: pyfirstcsh <8295488+cao-shuai-hu@user.noreply.gitee.com> --- .github/workflows/test.yml | 3 +- src/main/scala/kernel/alu/AttnScores.scala | 674 +++++++++++++++ src/main/scala/kernel/alu/Gemm.scala | 2 +- src/main/scala/kernel/alu/GemmFMA.scala | 423 ++++++++++ src/main/scala/kernel/alu/OutValue.scala | 199 +++++ .../scala/kernel/alu/AttnScoresTest.scala | 764 ++++++++++++++++++ src/test/scala/kernel/alu/GemmFMATest.scala | 444 ++++++++++ src/test/scala/kernel/alu/OutValueTest.scala | 191 +++++ src/test/scala/kernel/alu/utils.scala | 214 +++++ 9 files changed, 2912 insertions(+), 2 deletions(-) create mode 100644 src/main/scala/kernel/alu/AttnScores.scala create mode 100644 src/main/scala/kernel/alu/GemmFMA.scala create mode 100644 src/main/scala/kernel/alu/OutValue.scala create mode 100644 src/test/scala/kernel/alu/AttnScoresTest.scala create mode 100644 src/test/scala/kernel/alu/GemmFMATest.scala create mode 100644 src/test/scala/kernel/alu/OutValueTest.scala create mode 100644 src/test/scala/kernel/alu/utils.scala diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 28553de..aed9351 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,7 +29,8 @@ jobs: apps: sbt - name: Setup Dependencies run: | - sudo apt-get install ccache + sudo apt-get install ccache g++-11 # install GCC 11 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 11 - name: Get Cached Verilator id: get-cached-verilator uses: actions/cache@v4 diff --git 
a/src/main/scala/kernel/alu/AttnScores.scala b/src/main/scala/kernel/alu/AttnScores.scala new file mode 100644 index 0000000..60df6d3 --- /dev/null +++ b/src/main/scala/kernel/alu/AttnScores.scala @@ -0,0 +1,674 @@ +package kernel.alu + +import chisel3._ +import chisel3.util._ +import kernel.alu.GEMMDataType +import kernel.alu.DataWidthConfig +import kernel.utils.DebugLog +import kernel.deprecated.PE + +// QKGenWithReg: use two GEMMFMATotal to get Q and K +// internally uses Reg to store Q and K +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: Query: m * n +// output: Key: m * n +class QKGenWithReg( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + io.Key.valid := false.B + io.Key.bits := DontCare + io.Query.valid := false.B + io.Query.bits := DontCare + + val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + + val Qreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + val Kreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + + qGen.io.matrixA.bits := io.inputToken.bits + qGen.io.matrixA.valid := io.inputToken.valid + 
qGen.io.matrixB.bits := io.weightQ.bits + qGen.io.matrixB.valid := io.weightQ.valid + qGen.io.results.ready := false.B + + kGen.io.matrixA.bits := io.inputToken.bits + kGen.io.matrixA.valid := io.inputToken.valid + kGen.io.matrixB.bits := io.weightK.bits + kGen.io.matrixB.valid := io.weightK.valid + kGen.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, gen, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B + io.weightK.ready := false.B + stateReg := state.gen + } + } + is(state.gen) { + qGen.io.results.ready := true.B + kGen.io.results.ready := true.B + when(qGen.io.results.valid && kGen.io.results.valid) { + Qreg := qGen.io.results.bits + Kreg := kGen.io.results.bits + stateReg := state.done + } + } + is(state.done) { + qGen.io.results.ready := false.B + kGen.io.results.ready := false.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + // readyReg := true.B + io.Query.valid := true.B + io.Key.valid := true.B + io.Query.bits := Qreg + io.Key.bits := Kreg + stateReg := state.idle + } + } +} + +// QKGen: use two GEMMFMATotal to get Q and K +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: Query: m * n +// output: Key: m * n +class QKGen( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + 
}) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + io.Key.valid := false.B + io.Key.bits := DontCare + io.Query.valid := false.B + io.Query.bits := DontCare + + val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + + qGen.io.matrixA.bits := io.inputToken.bits + qGen.io.matrixA.valid := io.inputToken.valid + qGen.io.matrixB.bits := io.weightQ.bits + qGen.io.matrixB.valid := io.weightQ.valid + qGen.io.results.ready := false.B + + kGen.io.matrixA.bits := io.inputToken.bits + kGen.io.matrixA.valid := io.inputToken.valid + kGen.io.matrixB.bits := io.weightK.bits + kGen.io.matrixB.valid := io.weightK.valid + kGen.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, gen, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B + io.weightK.ready := false.B + stateReg := state.gen + } + } + is(state.gen) { + qGen.io.results.ready := true.B + kGen.io.results.ready := true.B + when(qGen.io.results.valid && kGen.io.results.valid) { + stateReg := state.done + } + } + is(state.done) { + qGen.io.results.ready := false.B + kGen.io.results.ready := false.B + // readyReg := true.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + io.Query.valid := true.B + io.Key.valid := true.B + io.Query.bits := qGen.io.results.bits + io.Key.bits := kGen.io.results.bits + stateReg := state.idle + } + } +} + +// QKMulWithReg: use GEMMFMATotal to get scores +// input: Query: m * n +// input: Key: m * n +// output: scores: m * m +class QKMulWithReg( + 
val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Query.valid && io.Key.valid + + // val readyReg = RegInit(true.B) + // io.Query.ready := readyReg + // io.Key.ready := readyReg + io.Query.ready := true.B + io.Key.ready := true.B + io.scores.valid := false.B + io.scores.bits := DontCare + + val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) + val scoresReg = Reg(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.Query.ready := false.B + stateReg := state.mul + } + } + is(state.mul) { + QK_TMul.io.results.ready := true.B + when(QK_TMul.io.results.valid) { + scoresReg := QK_TMul.io.results.bits + stateReg := state.done + } + } + is(state.done) { + QK_TMul.io.results.ready := false.B + // readyReg := true.B + io.Query.ready := true.B + io.scores.valid := true.B + io.scores.bits := scoresReg + stateReg := state.idle + } + } +} + +// QKMul: use GEMMFMATotal to get scores +// input: Query: m * n +// input: Key: m * n +// output: scores: m * m +class QKMul( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + 
val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Query.valid && io.Key.valid + + // val readyReg = RegInit(true.B) + // io.Query.ready := readyReg + // io.Key.ready := readyReg + io.Query.ready := true.B + io.Key.ready := true.B + io.scores.valid := false.B + io.scores.bits := DontCare + + val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) + + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.Query.ready := false.B + stateReg := state.mul + } + } + is(state.mul) { + QK_TMul.io.results.ready := true.B + when(QK_TMul.io.results.valid) { + stateReg := state.done + } + } + is(state.done) { + QK_TMul.io.results.ready := false.B + // readyReg := true.B + io.Query.ready := true.B + io.scores.valid := true.B + io.scores.bits := QK_TMul.io.results.bits + stateReg := state.idle + } + } +} +// QKMulSingle: use GEMMFMASingle to get scores by row +// input: Query: m * n +// input: Key: m * n +// output: scores: 1 * m +// output: done: Bool +class QKMulSingle( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val curRowScores = Decoupled(new curRowIndex(m, m)) + val done = Output(Bool()) + }) + + 
val dataValid = io.Query.valid && io.Key.valid + + val readyReg = RegInit(true.B) + io.Query.ready := readyReg + io.Key.ready := readyReg + io.curRowScores.valid := false.B + io.curRowScores.bits := DontCare + io.done := false.B + + val curRowIndexReg = Reg(new curRowIndex(m, m)) + + val QK_TMul = Module(new GEMMFMASingle(m, n, m, peCount, gemmType)) + // val QK_TMul = Module(new GEMMSingleQueue(m, n, m, peCount, gemmType)) + + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.curRow.ready := false.B + + object state extends ChiselEnum { + val idle, mul, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.mul + } + } + is(state.mul) { + QK_TMul.io.curRow.ready := true.B + when(QK_TMul.io.curRow.valid) { + curRowIndexReg := QK_TMul.io.curRow.bits + stateReg := state.update + } + } + is(state.update) { + QK_TMul.io.curRow.ready := false.B + io.curRowScores.valid := true.B + io.curRowScores.bits := curRowIndexReg + when(QK_TMul.io.done) { + stateReg := state.done + }.otherwise { + stateReg := state.mul + } + } + is(state.done) { + readyReg := true.B + io.done := true.B + stateReg := state.idle + } + } +} + +// AttnScores: use QKGen to get Q and K, then use QKMul to get scores +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: scores: m * m +class AttnScores( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = 
Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + io.scores.valid := false.B + io.scores.bits := DontCare + + // val scoresReg = Reg(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + // val QKGen = Module(new QKGenWithReg(m, k, n, peCount, gemmType)) + val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) + + QKGen.io.inputToken.valid := io.inputToken.valid + QKGen.io.inputToken.bits := io.inputToken.bits + QKGen.io.weightQ.valid := io.weightQ.valid + QKGen.io.weightQ.bits := io.weightQ.bits + QKGen.io.weightK.valid := io.weightK.valid + QKGen.io.weightK.bits := io.weightK.bits + + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + + // val QKMul = Module(new QKMulWithReg(m, n, peCount, gemmType)) + val QKMul = Module(new QKMul(m, n, peCount, gemmType)) + + QKMul.io.Query.valid := QKGen.io.Query.valid + QKMul.io.Query.bits := QKGen.io.Query.bits + QKMul.io.Key.valid := QKGen.io.Key.valid + QKMul.io.Key.bits := QKGen.io.Key.bits + QKMul.io.scores.ready := false.B + + object state extends ChiselEnum { + val idle, gen, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B + stateReg := state.gen + } + } + is(state.gen) { + QKGen.io.Query.ready := true.B + QKGen.io.Key.ready := true.B + when(QKGen.io.Query.valid && QKGen.io.Key.valid) { + stateReg := state.mul + } + } + is(state.mul) { + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + QKMul.io.scores.ready := true.B + 
when(QKMul.io.scores.valid) { + // scoresReg := QKMul.io.scores.bits + stateReg := state.done + } + } + is(state.done) { + QKMul.io.scores.ready := false.B + // readyReg := true.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.scores.valid := true.B + // io.scores.bits := scoresReg + io.scores.bits := QKMul.io.scores.bits + stateReg := state.idle + } + } +} + +// AttnScoresSingle: use QKGen to get Q and K, get scores by row +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: curRowScores: 1 * m +// output: done: Bool +class AttnScoresSingle( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val curRowScores = Decoupled(new curRowIndex(m, m)) + val done = Output(Bool()) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + io.curRowScores.valid := false.B + io.curRowScores.bits := DontCare + io.done := false.B + + // val QKGen = Module(new QKGenWithReg(m, k, n, peCount, gemmType)) + val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) + + QKGen.io.inputToken.valid := io.inputToken.valid + QKGen.io.inputToken.bits := io.inputToken.bits + QKGen.io.weightQ.valid := io.weightQ.valid + QKGen.io.weightQ.bits := io.weightQ.bits + QKGen.io.weightK.valid := io.weightK.valid + QKGen.io.weightK.bits := io.weightK.bits + + QKGen.io.Query.ready := false.B + 
QKGen.io.Key.ready := false.B + + val curRowIndexReg = Reg(new curRowIndex(m, m)) + val QKMul = Module(new GEMMFMASingle(m, n, m, peCount, gemmType)) + + QKMul.io.matrixA.valid := QKGen.io.Query.valid + QKMul.io.matrixA.bits := QKGen.io.Query.bits + QKMul.io.matrixB.valid := QKGen.io.Key.valid + QKMul.io.matrixB.bits := VecInit(QKGen.io.Key.bits.transpose.map(VecInit(_))) + QKMul.io.curRow.ready := false.B + + object state extends ChiselEnum { + val idle, gen, mul, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B + io.weightK.ready := false.B + stateReg := state.gen + } + } + is(state.gen) { + QKGen.io.Query.ready := true.B + QKGen.io.Key.ready := true.B + when(QKGen.io.Query.valid && QKGen.io.Key.valid) { + stateReg := state.mul + } + } + is(state.mul) { + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + QKMul.io.curRow.ready := true.B + when(QKMul.io.curRow.valid) { + curRowIndexReg := QKMul.io.curRow.bits + stateReg := state.update + } + } + is(state.update) { + QKMul.io.curRow.ready := false.B + io.curRowScores.valid := true.B + io.curRowScores.bits := curRowIndexReg + when(QKMul.io.done) { + stateReg := state.done + }.otherwise { + stateReg := state.mul + } + } + is(state.done) { + // readyReg := true.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + io.done := true.B + stateReg := state.idle + } + } +} + +// AttnScoresSingleQueue: use Queue to store scores +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// input: flush: Bool +// output: curRowScores: 1 * m +// output: done: Bool +class AttnScoresSingleQueue( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type, + val bufferSize: Int = 32 +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog 
{ + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val flush = Input(Bool()) + val curRowScores = Decoupled(new curRowIndex(m, m)) + val done = Output(Bool()) + }) + + val curBuffer = Module( + new Queue( + new curRowIndex(m, m), + entries = bufferSize, + pipe = true, + flow = false, + useSyncReadMem = false, + hasFlush = true + ) + ) + val doneReg = RegInit(false.B) + + val attnScores = Module(new AttnScoresSingle(m, k, n, peCount, gemmType)) + attnScores.io.inputToken <> io.inputToken + attnScores.io.weightQ <> io.weightQ + attnScores.io.weightK <> io.weightK + + curBuffer.io.flush.get := io.flush + curBuffer.io.enq <> attnScores.io.curRowScores + doneReg := attnScores.io.done + + io.curRowScores <> curBuffer.io.deq + io.done := doneReg + +} diff --git a/src/main/scala/kernel/alu/Gemm.scala b/src/main/scala/kernel/alu/Gemm.scala index c7497fb..8225b70 100644 --- a/src/main/scala/kernel/alu/Gemm.scala +++ b/src/main/scala/kernel/alu/Gemm.scala @@ -8,7 +8,7 @@ import fputil.FPAdd import hardfloat._ trait GEMMAccuracyConfig { - val I: Int = 8 + val I: Int = 10 val F: Int = 0 } diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala new file mode 100644 index 0000000..d57cfe4 --- /dev/null +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -0,0 +1,423 @@ +package kernel.alu + +import chisel3._ +import chisel3.util._ +import kernel.alu.GEMMDataType +import kernel.alu.DataWidthConfig +import kernel.utils.DebugLog +import kernel.deprecated.PE + +class curRowIndex( + val m: Int, + val n: Int +)( + implicit config: DataWidthConfig) + extends Bundle { + val index = Output(UInt(log2Ceil(m).W)) // index of the output row + val value = Output(Vec(n, UInt(config.outputWidth.W))) // values of the output row +} + +// input: matrixA_row: one row of matrixA : 1 * k +// 
input: matrixB_cols: peCount rows of matrixB : k * peCount +// input: reset: clear pes old data +// output: blockResult: one block of result : 1 * peCount +class MultiFMA( + val k: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val matrixA_row = Flipped(Decoupled(Vec(k, UInt(config.inputWidth.W)))) + val matrixB_cols = Flipped(Decoupled(Vec(k, Vec(peCount, UInt(config.inputWidth.W))))) + val blockResult = Decoupled(Vec(peCount, UInt(config.outputWidth.W))) + val reset = Input(Bool()) + }) + + val dataValid = io.matrixA_row.valid && io.matrixB_cols.valid + + val readyReg = RegInit(true.B) + io.matrixA_row.ready := readyReg + io.matrixB_cols.ready := readyReg + io.blockResult.valid := false.B + io.blockResult.bits := DontCare + + val pes = Seq.fill(peCount)(gemmType match { + case GEMMDataType.Fxp => Module(new PEFxp()).io + case GEMMDataType.Fp32 => Module(new PEFp()).io + case GEMMDataType.Fp64 => Module(new PEFp()).io + case _ => throw new IllegalArgumentException("Unsupported GEMM type") + }) + + val optIndex = RegInit(0.U(log2Ceil(k).W)) + val validReg = RegInit(false.B) + + for (i <- 0 until peCount) { + pes(i).reset := io.reset + pes(i).in_h := io.matrixA_row.bits(optIndex) + pes(i).in_v := io.matrixB_cols.bits(optIndex)(i) + io.blockResult.bits(i) := pes(i).out + } + io.blockResult.valid := validReg + + when(dataValid) { + readyReg := false.B + } + when(io.reset) { + optIndex := 0.U + validReg := false.B + }.elsewhen(optIndex === (k - 1).U) { + optIndex := 0.U + validReg := true.B + readyReg := true.B + }.otherwise { + validReg := false.B + optIndex := optIndex + 1.U + } + +} + +class MultiFMA_v2( + val k: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val matrixA_row = Flipped(Decoupled(Vec(k, UInt(config.inputWidth.W)))) + val 
matrixB_cols = Flipped(Decoupled(Vec(k, Vec(peCount, UInt(config.inputWidth.W))))) + val blockResult = Decoupled(Vec(peCount, UInt(config.outputWidth.W))) + val reset = Input(Bool()) + }) + + val dataValid = io.matrixA_row.valid && io.matrixB_cols.valid + + val readyReg = RegInit(true.B) + io.matrixA_row.ready := readyReg + io.matrixB_cols.ready := readyReg + io.blockResult.valid := false.B + io.blockResult.bits := DontCare + + val pes = Seq.fill(peCount)(gemmType match { + case GEMMDataType.Fxp => Module(new PEFxp()).io + case GEMMDataType.Fp32 => Module(new PEFp()).io + case GEMMDataType.Fp64 => Module(new PEFp()).io + case _ => throw new IllegalArgumentException("Unsupported GEMM type") + }) + + val optIndex = Counter(k) + + pes.foreach { pe => + pe.in_h := 0.U + pe.in_v := 0.U + pe.reset := DontCare + } + + object state extends ChiselEnum { + val idle, compute, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(io.reset) { + optIndex.reset() + for (i <- 0 until peCount) { + pes(i).reset := true.B + } + } + when(dataValid) { + readyReg := false.B + stateReg := state.compute + } + } + is(state.compute) { + for (i <- 0 until peCount) { + pes(i).reset := false.B + pes(i).in_h := io.matrixA_row.bits(optIndex.value) + pes(i).in_v := io.matrixB_cols.bits(optIndex.value)(i) + io.blockResult.bits(i) := pes(i).out + } + stateReg := state.update + } + is(state.update) { + when(optIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.compute + } + } + is(state.done) { + readyReg := true.B + io.blockResult.valid := true.B + stateReg := state.idle + } + } +} +// input: matrixA: m * k +// input: matrixB: k * n +// output: matrixC: m * n +class GEMMFMATotal( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + require(m % peCount == 0 && k % peCount == 0 && n % peCount == 0, 
"Matrix dimensions must be divisible by peCount") + val io = IO(new Bundle { + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val results = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.matrixA.valid && io.matrixB.valid + + // TODO(ERR): driving ready/valid from readyReg and resValid causes a deadlock: + // chiseltest reports 'Exception in thread "chiseltest_thread_2" java.lang.RuntimeException: Deadlock!' + // when the test uses fork() and join(). + + + // TODO: + // Every module that drives ready from readyReg can hit this deadlock, + // whereas driving it with the literals true.B / false.B does not, + // when fork() and join() are used in chiseltest; + // the root cause has not been found yet. + + // val readyReg = RegInit(true.B) + // val resValid = RegInit(false.B) + // io.matrixA.ready := readyReg + // io.matrixB.ready := readyReg + // io.results.valid := resValid + io.matrixA.ready := true.B + io.matrixB.ready := true.B + io.results.valid := false.B + io.results.bits := DontCare + + val multiFMA = Module(new MultiFMA_v2(k, peCount, gemmType)) + + val rowIndex = Counter(m) + val colIndex = Counter(n / peCount) + + multiFMA.io.matrixA_row.valid := io.matrixA.valid + multiFMA.io.matrixA_row.bits := io.matrixA.bits(rowIndex.value) + + multiFMA.io.matrixB_cols.valid := io.matrixB.valid + multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(k) { j => + VecInit(Seq.tabulate(peCount) { i => + io.matrixB.bits(j)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) + }) + }) //k * peCount size block of matrixB + + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + + val resultsReg = Reg(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + + object state extends ChiselEnum { + val idle, compute, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.matrixA.ready := false.B + io.matrixB.ready := false.B + stateReg 
:= state.compute + } + } + + is(state.compute) { + multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B + when(multiFMA.io.blockResult.valid) { + for (i <- 0 until peCount) { + resultsReg(rowIndex.value)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) := multiFMA.io.blockResult + .bits(i) + } + stateReg := state.update + } + } + + is(state.update) { + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + when(colIndex.inc()) { + when(rowIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.compute + } + }.otherwise { + stateReg := state.compute + } + } + is(state.done) { + // resValid := true.B + // readyReg := true.B + io.matrixA.ready := true.B + io.matrixB.ready := true.B + io.results.valid := true.B + io.results.bits := resultsReg + stateReg := state.idle + } + } + +} + +//input: matrixA: m * k +//input: matrixB: k * n +//output: curRowIndex: one row of matrixC: 1 * n and cur row index +//output: done: total matrixC finish flag +class GEMMFMASingle( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + require(m % peCount == 0 && k % peCount == 0 && n % peCount == 0, "Matrix dimensions must be divisible by peCount") + val io = IO(new Bundle { + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val curRow = Decoupled(new curRowIndex(m, n)) + val done = Output(Bool()) + }) + + val dataValid = io.matrixA.valid && io.matrixB.valid + // val readyReg = RegInit(true.B) + // io.matrixA.ready := readyReg + // io.matrixB.ready := readyReg + io.matrixA.ready := true.B + io.matrixB.ready := true.B + io.curRow.valid := false.B + io.curRow.bits := DontCare + io.done := false.B + + val multiFMA = Module(new MultiFMA(k, peCount, gemmType)) + + val rowIndex = Counter(m) + val colIndex = 
Counter(n / peCount) + + multiFMA.io.matrixA_row.valid := io.matrixA.valid + multiFMA.io.matrixA_row.bits := io.matrixA.bits(rowIndex.value) + + multiFMA.io.matrixB_cols.valid := io.matrixB.valid + multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(k) { j => + VecInit(Seq.tabulate(peCount) { i => + io.matrixB.bits(j)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) + }) + }) //k * peCount size block of matrixB + + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + + val curRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) + + object state extends ChiselEnum { + val idle, compute, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.matrixA.ready := false.B + io.matrixB.ready := false.B + stateReg := state.compute + } + } + + is(state.compute) { + multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B + when(multiFMA.io.blockResult.valid) { + for (i <- 0 until peCount) { + curRowReg((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) := multiFMA.io.blockResult.bits(i) + } + stateReg := state.update + } + } + + is(state.update) { + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + when(colIndex.inc()) { + io.curRow.valid := true.B + io.curRow.bits.index := rowIndex.value + io.curRow.bits.value := curRowReg + when(rowIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.compute + } + }.otherwise { + stateReg := state.compute + } + + } + is(state.done) { + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + io.matrixA.ready := true.B + io.matrixB.ready := true.B + io.done := true.B + // readyReg := true.B + stateReg := state.idle + } + } +} + +class GEMMSingleQueue( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type, + val bufferSize: Int = 32 +)( + implicit config: DataWidthConfig) + extends Module + with 
DebugLog { + val io = IO(new Bundle { + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val flush = Input(Bool()) + val curRow = Decoupled(new curRowIndex(m, n)) + val done = Output(Bool()) + }) + + val curBuffer = Module( + new Queue( + new curRowIndex(m, n), + entries = bufferSize, + pipe = true, + flow = false, + useSyncReadMem = false, + hasFlush = true + ) + ) + val doneReg = RegInit(false.B) + val gemm = Module(new GEMMFMASingle(m, k, n, peCount, gemmType)) + gemm.io.matrixA <> io.matrixA + gemm.io.matrixB <> io.matrixB + curBuffer.io.flush.get := io.flush + curBuffer.io.enq <> gemm.io.curRow + doneReg := gemm.io.done + io.curRow <> curBuffer.io.deq + io.done := doneReg + +} diff --git a/src/main/scala/kernel/alu/OutValue.scala b/src/main/scala/kernel/alu/OutValue.scala new file mode 100644 index 0000000..2c875ad --- /dev/null +++ b/src/main/scala/kernel/alu/OutValue.scala @@ -0,0 +1,199 @@ +package kernel.alu + +import chisel3._ +import chisel3.util._ +import kernel.alu.GEMMDataType +import kernel.alu.DataWidthConfig +import kernel.utils.DebugLog +import kernel.deprecated.PE + +// OutValue: multiplies the attention scores by Value to get the final output +// input: Scores (attention weights): m * m +// input: Value: m * n +// output: AttnOut: m * n, produced by a single GEMMFMATotal(m, m, n) +// FSM: idle -> compute -> done. AttnOut.valid is high only in done, which +// lasts one cycle and does not wait for AttnOut.ready. +class OutValue( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Scores = Flipped(Decoupled(Vec(m, Vec(m, UInt(config.inputWidth.W))))) + val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val AttnOut = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Scores.valid && io.Value.valid + + // val readyReg = RegInit(true.B) + // io.Scores.ready := readyReg + // io.Value.ready := readyReg + io.Scores.ready := true.B + io.Value.ready := true.B + io.AttnOut.valid := 
false.B + io.AttnOut.bits := DontCare + + val ValueMul = Module(new GEMMFMATotal(m, m, n, peCount, gemmType)) + + ValueMul.io.matrixA.valid := io.Scores.valid + ValueMul.io.matrixA.bits := io.Scores.bits + ValueMul.io.matrixB.valid := io.Value.valid + ValueMul.io.matrixB.bits := io.Value.bits + ValueMul.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, compute, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + io.Scores.ready := false.B + io.Value.ready := false.B + stateReg := state.compute + } + } + is(state.compute) { + ValueMul.io.results.ready := true.B + when(ValueMul.io.results.valid) { + stateReg := state.done + } + } + is(state.done) { + ValueMul.io.results.ready := false.B + // readyReg := true.B + io.Scores.ready := true.B + io.Value.ready := true.B + io.AttnOut.valid := true.B + io.AttnOut.bits := ValueMul.io.results.bits + stateReg := state.idle + } + } +} + +// OutValueSingle: row-by-row version of OutValue +// input: one row of AttnWeights: 1 * m ,total m rows +// input: Value: m * n +// output: one row of AttnOut: 1 * n ,total m rows +// output: done: Bool, high for the single cycle spent in the done state +// Each output row is built from blocks of peCount columns computed by one MultiFMA. +// NOTE(review): curAttnOut.valid is pulsed for one cycle in the update state and is +// not held until curAttnOut.ready -- confirm every consumer samples it in that cycle. +class OutValueSingle( + val m: Int, + val n: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + // colIndex below counts n / peCount blocks; when n > peCount and the division is not + // exact the trailing columns would never be computed, so reject that at elaboration. + require( + peCount > 0 && (n <= peCount || n % peCount == 0), + s"OutValueSingle: n ($n) must be a multiple of peCount ($peCount) when n > peCount" + ) + + val io = IO(new Bundle { + val curScores = Flipped(Decoupled(new curRowIndex(m, m))) + val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val curAttnOut = Decoupled(new curRowIndex(m, n)) + val done = Output(Bool()) + }) + + val dataValid = io.curScores.valid && io.Value.valid + + io.curScores.ready := true.B + io.Value.ready := true.B + io.curAttnOut.valid := false.B + io.curAttnOut.bits := DontCare + io.done := false.B + + // Holds the Value matrix for the whole m-row computation; written only in state.idle + val ValueReg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + + val multiFMA = Module(new MultiFMA(m, peCount, gemmType)) + + // val 
rowIndex = Counter(m) + val colIndex = Counter(n / peCount) + + multiFMA.io.matrixA_row.valid := io.curScores.valid + multiFMA.io.matrixA_row.bits := io.curScores.bits.value + + multiFMA.io.matrixB_cols.valid := io.Value.valid + multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(m) { j => + VecInit(Seq.tabulate(peCount) { i => + ValueReg(j)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) + }) + }) //m * peCount size block of Value + + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + + val curRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) + val curRowIndex = RegInit(0.U) + + object state extends ChiselEnum { + val idle, compute, update, output, load, done = Value + } + + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // Latch Value once on acceptance so later changes on io.Value.bits + // cannot disturb the rows still being computed. + ValueReg := io.Value.bits + io.Value.ready := false.B + io.curScores.ready := false.B + stateReg := state.compute + } + } + is(state.compute) { + io.curScores.ready := false.B + io.Value.ready := false.B + + multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B + curRowIndex := io.curScores.bits.index + when(multiFMA.io.blockResult.valid) { + for (i <- 0 until peCount) { + curRowReg((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) := multiFMA.io.blockResult.bits(i) + } + stateReg := state.update + } + } + is(state.update) { + io.curScores.ready := false.B + io.Value.ready := false.B + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + // io.curAttnOut.valid := false.B + when(colIndex.inc()) { + io.curAttnOut.valid := true.B + io.curAttnOut.bits.index := curRowIndex + io.curAttnOut.bits.value := curRowReg + stateReg := state.output + }.otherwise { + stateReg := state.compute + } + } + is(state.output) { + io.curScores.ready := false.B + io.Value.ready := false.B + when(io.curScores.bits.index === (m - 1).U) { + // print(p"curScores.ready: ${io.curScores.ready}") + stateReg := state.done + }.otherwise { + stateReg := state.load + } + } + 
is(state.load) { + io.curScores.ready := true.B + io.Value.ready := false.B + io.curAttnOut.valid := false.B + stateReg := state.compute + } + is(state.done) { + io.done := true.B + io.Value.ready := true.B + io.curScores.ready := true.B + io.curAttnOut.valid := false.B + stateReg := state.idle + } + } +} diff --git a/src/test/scala/kernel/alu/AttnScoresTest.scala b/src/test/scala/kernel/alu/AttnScoresTest.scala new file mode 100644 index 0000000..83eee57 --- /dev/null +++ b/src/test/scala/kernel/alu/AttnScoresTest.scala @@ -0,0 +1,764 @@ +package kernel.alu + +import chisel3._ +import chiseltest._ +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.ParallelTestExecution +import scala.reflect.ClassTag +import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} +import ujson.Arr +import Utils._ + +class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { + + // Drives caseNum random (inputToken, weightQ, weightK) triples into QKGen from one + // thread and compares Query/Key against mmul references from a second thread. + // NOTE(review): mismatches are only printed; nothing in this helper fails the test. + private def testQKGen[T: Numeric: ClassTag]( + dut: QKGen + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + val caseNum = 10 + + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + (inputToken, weightQ, weightK, exQuery, exKey) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + // Poke each element exactly once; a single row/col/i loop would rewrite + // every inputToken element n times and every weight element m times. + for { + row <- 0 until m + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + } + 
for { + i <- 0 until k + col <- 0 until n + } { + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exQuery, exKey) = testCases(resCnt) + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean()) { + dut.io.Key.ready.poke(true.B) + dut.io.Query.ready.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + val outQBigInt = dut.io.Query.bits(row)(col).peekInt() + val outKBigInt = dut.io.Key.bits(row)(col).peekInt() + val outQ = fromBinaryBigInt[T](outQBigInt) + val outK = fromBinaryBigInt[T](outKBigInt) + val expectedQ = exQuery(row)(col) + val expectedK = exKey(row)(col) + checkResult(outQ, expectedQ, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + checkResult(outK, expectedK, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.Key.ready.poke(false.B) + dut.io.Query.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + } + private def testQKGenWithReg[T: Numeric: ClassTag]( + dut: QKGenWithReg + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + val caseNum = 10 + + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) 
+ (inputToken, weightQ, weightK, exQuery, exKey) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + // Poke each element exactly once; a single row/col/i loop would rewrite + // every inputToken element n times and every weight element m times. + for { + row <- 0 until m + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + } + for { + i <- 0 until k + col <- 0 until n + } { + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exQuery, exKey) = testCases(resCnt) + if (dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean()) { + dut.io.Key.ready.poke(true.B) + dut.io.Query.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until n + } { + val outQBigInt = dut.io.Query.bits(row)(col).peekInt() + val outKBigInt = dut.io.Key.bits(row)(col).peekInt() + val outQ = fromBinaryBigInt[T](outQBigInt) + val outK = fromBinaryBigInt[T](outKBigInt) + val expectedQ = exQuery(row)(col) + val expectedK = exKey(row)(col) + checkResult(outQ, expectedQ, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + checkResult(outK, expectedK, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + 
dut.io.Key.ready.poke(false.B) + dut.io.Query.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + } + + // Feeds caseNum random Query/Key pairs (m * n each) into QKMul and compares the + // m * m scores output with mmul(Query, Key.transpose). + // NOTE(review): mismatches are only printed; nothing in this helper fails the test. + private def testQKMul[T: Numeric: ClassTag]( + dut: QKMul + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inQuery = matInit[T](m, n) + val inKey = matInit[T](m, n) + val exScores = mmul(inQuery, inKey.transpose) + (inQuery, inKey, exScores) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inQuery, inKey, _) = testCases(cnt) + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("test case " + cnt + ": Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(inQuery(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(inKey(row)(col)).U) + } + cnt += 1 + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exScores) = testCases(resCnt) + if (dut.io.scores.valid.peekBoolean()) { + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.scores.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + + } + + // Same stimulus and check as testQKMul, applied to the QKMulWithReg variant. + private def testQKMulWithReg[T: Numeric: ClassTag]( + dut: QKMulWithReg + )( + implicit 
config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inQuery = matInit[T](m, n) + val inKey = matInit[T](m, n) + val exScores = mmul(inQuery, inKey.transpose) + (inQuery, inKey, exScores) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inQuery, inKey, _) = testCases(cnt) + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("test case " + cnt + ": Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(inQuery(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(inKey(row)(col)).U) + } + cnt += 1 + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exScores) = testCases(resCnt) + if (dut.io.scores.valid.peekBoolean()) { + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.scores.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + + } + + // Row-streaming variant: expects the m score rows on curRowScores in index order + // (0 until m) for every case, then io.done high. + private def testQKMulSingle[T: Numeric: ClassTag]( + dut: QKMulSingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val peCount = dut.peCount + val gemmType = dut.gemmType + + val caseNum = 10 + val testCases = 
Array.tabulate(caseNum) { i => + val inQuery = matInit[T](m, n) + val inKey = matInit[T](m, n) + val exScores = mmul(inQuery, inKey.transpose) + (inQuery, inKey, exScores) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inQuery, inKey, _) = testCases(cnt) + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("test case " + cnt + ": Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(inQuery(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(inKey(row)(col)).U) + } + cnt += 1 + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exScores) = testCases(resCnt) + var rowIdx = 0 + // Rows must arrive in order: each valid beat is checked against row rowIdx + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRowScores.valid.peekBoolean()) { + dut.io.curRowScores.ready.poke(true.B) + val curRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + dut.io.curRowScores.bits.index.expect(rowIdx.U) + for (col <- 0 until m) { + val outBigInt = dut.io.curRowScores.bits.value(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(rowIdx)(col) + checkResult(out, expected, rowIdx, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRowScores.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() + + } + + // End-to-end check of AttnScores: the m * m scores output is compared with + // mmul(inputToken * weightQ, (inputToken * weightK).transpose). + private def testAttnScores[T: Numeric: ClassTag]( + dut: AttnScores + )( + implicit config: 
DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + val exScores = mmul(exQuery, exKey.transpose) + (inputToken, weightQ, weightK, exScores) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + // Poke each element exactly once; a single row/col/i loop would rewrite + // every inputToken element n times and every weight element m times. + for { + row <- 0 until m + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + } + for { + i <- 0 until k + col <- 0 until n + } { + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exScores) = testCases(resCnt) + if (dut.io.scores.valid.peekBoolean()) { + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : Verification passed!") + else println(s"case $resCnt : Verification 
failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.scores.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + } + + // Row-streaming end-to-end check: AttnScoresSingle must emit the m score rows on + // curRowScores in index order for every case, then raise io.done. + // NOTE(review): value mismatches are only printed; only the row index and done are asserted. + private def testAttnScoresSingle[T: Numeric: ClassTag]( + dut: AttnScoresSingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + val exScores = mmul(exQuery, exKey.transpose) + (inputToken, weightQ, weightK, exScores) + } + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + // Poke each element exactly once; a single row/col/i loop would rewrite + // every inputToken element n times and every weight element m times. + for { + row <- 0 until m + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + } + for { + i <- 0 until k + col <- 0 until n + } { + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exScores) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRowScores.valid.peekBoolean()) { + dut.io.curRowScores.ready.poke(true.B) + val curRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + 
dut.io.curRowScores.bits.index.expect(rowIdx.U) + for { + col <- 0 until m + } { + val outBigInt = dut.io.curRowScores.bits.value(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(rowIdx)(col) + checkResult(out, expected, rowIdx, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRowScores.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() + } + + // Same row-streaming check as testAttnScoresSingle, for the queue-buffered DUT + // that additionally exposes io.flush. + // NOTE(review): the driver pulses io.flush on every loop iteration, including the + // iterations that only wait for ready -- confirm this cannot discard rows that are + // still buffered for the checker thread. + private def testAttnScoresSingleQueue[T: Numeric: ClassTag]( + dut: AttnScoresSingleQueue + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + val exScores = mmul(exQuery, exKey.transpose) + (inputToken, weightQ, weightK, exScores) + } + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _) = testCases(cnt) + dut.io.flush.poke(true.B) + dut.clock.step() + dut.io.flush.poke(false.B) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + // Poke each element exactly once; a single row/col/i loop would rewrite + // every inputToken element n times and every weight element m times. + for { + row <- 0 until m + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + } + for { + i <- 0 until k + col <- 0 until n + } { + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + 
dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exScores) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRowScores.valid.peekBoolean()) { + dut.io.curRowScores.ready.poke(true.B) + val curRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + dut.io.curRowScores.bits.index.expect(rowIdx.U) + for { + col <- 0 until m + } { + val outBigInt = dut.io.curRowScores.bits.value(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(rowIdx)(col) + checkResult(out, expected, rowIdx, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRowScores.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() + } + + "AttnScoresSingleQueue " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new AttnScoresSingleQueue(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testAttnScoresSingleQueue[Int](dut) + } + } + + // "AttnScoresSingleQueue " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new AttnScoresSingleQueue(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // 
testAttnScoresSingleQueue[Float](dut) + // } + // } + + // "AttnScoresSingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new AttnScoresSingle(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScoresSingle[Int](dut) + // } + // } + + // "AttnScoresSingle " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new AttnScoresSingle(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScoresSingle[Float](dut) + // } + // } + + // "QKMulSingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMulSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMulSingle[Int](dut) + // } + // } + + // "AttnScores " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new AttnScores(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScores[Int](dut) + // } + // } + + // "AttnScores " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new AttnScores(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScores[Float](dut) + // } + // } + + // "QKMul " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMul(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // 
.withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMul[Int](dut) + // } + // } + + // "QKMulWithReg " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMulWithReg(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMulWithReg[Int](dut) + // } + // } + + // "QKGen " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKGen(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKGen[Int](dut) + // } + // } + + // "QKGenWithReg " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKGenWithReg(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKGenWithReg[Int](dut) + // } + // } +} diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala new file mode 100644 index 0000000..c017c42 --- /dev/null +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -0,0 +1,444 @@ +package kernel.alu + +import chisel3._ +import chiseltest._ +import chisel3.util.DecoupledIO +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.ParallelTestExecution +import scala.reflect.ClassTag +import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} +import ujson.Arr +import Utils._ + +class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution { + + private def testMultiFMA[T: Numeric: ClassTag]( + dut: MultiFMA + )( + implicit config: DataWidthConfig + ): Unit = { + val k = dut.k + val peCount = dut.peCount 
+ val gemmType = dut.gemmType + + val matrixA_row = matInit[T](1, k) + val matrixB_cols = matInit[T](k, peCount) + val expectedResults = mmul(matrixA_row, matrixB_cols) + + printmat(matrixA_row) + printmat(matrixB_cols) + printmat(expectedResults) + + dut.io.reset.poke(true.B) + dut.clock.step(1) + dut.io.reset.poke(false.B) + + if (dut.io.matrixA_row.ready.peekBoolean() && dut.io.matrixB_cols.ready.peekBoolean()) { + println("matrixA_row and matrixB_cols are ready") + dut.io.matrixA_row.valid.poke(true.B) + dut.io.matrixB_cols.valid.poke(true.B) + + // matrixA_row(i) is poked once per i instead of once per (i, j) pair + for (i <- matrixA_row(0).indices) { + dut.io.matrixA_row.bits(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) + for (j <- 0 until peCount) { + dut.io.matrixB_cols.bits(i)(j).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) + } + } + } else { + dut.io.matrixA_row.valid.poke(false.B) + dut.io.matrixB_cols.valid.poke(false.B) + } + + // Step until the block result is flagged valid + while (!dut.io.blockResult.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.blockResult.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + + for (i <- 0 until peCount) { + val outBigInt = dut.io.blockResult.bits(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(0)(i) + checkResult(out, expected, 0, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + // Single-shot check of MultiFMA_v2: one 1 * k row times a k * peCount block, + // compared with the mmul reference once blockResult is valid. + private def testMultiFMA_v2[T: Numeric: ClassTag]( + dut: MultiFMA_v2 + )( + implicit config: DataWidthConfig + ): Unit = { + val k = dut.k + val peCount = dut.peCount + val gemmType = dut.gemmType + + val matrixA_row = matInit[T](1, k) + val matrixB_cols = matInit[T](k, peCount) + val expectedResults = mmul(matrixA_row, matrixB_cols) + + printmat(matrixA_row) + printmat(matrixB_cols) + printmat(expectedResults) + + dut.io.reset.poke(true.B) + dut.clock.step(1) + dut.io.reset.poke(false.B) + + if (dut.io.matrixA_row.ready.peekBoolean() && 
dut.io.matrixB_cols.ready.peekBoolean()) { + println("matrixA_row and matrixB_cols are ready") + dut.io.matrixA_row.valid.poke(true.B) + dut.io.matrixB_cols.valid.poke(true.B) + + // matrixA_row(i) is poked once per i instead of once per (i, j) pair + for (i <- matrixA_row(0).indices) { + dut.io.matrixA_row.bits(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) + for (j <- 0 until peCount) { + dut.io.matrixB_cols.bits(i)(j).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) + } + } + } else { + dut.io.matrixA_row.valid.poke(false.B) + dut.io.matrixB_cols.valid.poke(false.B) + } + + while (!dut.io.blockResult.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.blockResult.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + + for (i <- 0 until peCount) { + val outBigInt = dut.io.blockResult.bits(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(0)(i) + checkResult(out, expected, 0, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + // Feeds caseNum random (m * k, k * n) matrix pairs into GEMMFMATotal and compares + // the full m * n results output with the mmul reference. + // NOTE(review): mismatches are only printed; nothing in this helper fails the test. + private def testGEMMFMATotal[T: Numeric: ClassTag]( + dut: GEMMFMATotal + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + + val testCases = Array.tabulate(caseNum) { i => + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val exResult = mmul(matrixA, matrixB) + (matrixA, matrixB, exResult) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (matrixA, matrixB, _) = testCases(cnt) + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println(s"case $cnt : matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + // Poke each element exactly once; a single row/col/i loop would rewrite + // every matrixA element n times and every matrixB element m times. + for { + row <- 0 until m + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + } + for { + i <- 0 until k + col <- 0 until n + } { + 
dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + cnt += 1 + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exResult) = testCases(resCnt) + if (dut.io.results.valid.peekBoolean()) { + dut.io.results.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until n + } { + val outBigInt = dut.io.results.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exResult(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt Verification passed!") + else println(s"case $resCnt Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.results.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + } + + // Row-streaming variant of the GEMM check: one thread drives the input matrices + // into GEMMFMASingle while a second thread consumes the result rows. + private def testGEMMFMASingle[T: Numeric: ClassTag]( + dut: GEMMFMASingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val peCount = dut.peCount + val gemmType = dut.gemmType + + val caseNum = 10 + + val testCases = Array.tabulate(caseNum) { i => + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val exResult = mmul(matrixA, matrixB) + (matrixA, matrixB, exResult) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (matrixA, matrixB, _) = testCases(cnt) + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println(s"case $cnt : matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + // Poke each element exactly once; a single row/col/i loop would rewrite + // every matrixA element n times and every matrixB element m times. + for { + row <- 0 until m + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + } + for { + i <- 0 until k + col <- 0 until n + } { + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + cnt += 1 + } else { + 
dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exResult) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRow.valid.peekBoolean()) { + dut.io.curRow.ready.poke(true.B) + val curRowIndex = dut.io.curRow.bits.index.peekInt() + dut.io.curRow.bits.index.expect(rowIdx.U) + println(s"curRow index: $curRowIndex") + for (i <- 0 until n) { + val outBigInt = dut.io.curRow.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exResult(curRowIndex.toInt)(i) + checkResult(out, expected, rowIdx, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRow.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() + } + + private def testGEMMSingleQueue[T: Numeric: ClassTag]( + dut: GEMMSingleQueue + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + + val testCases = Array.tabulate(caseNum) { i => + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val exResult = mmul(matrixA, matrixB) + (matrixA, matrixB, exResult) + } + fork { + var cnt = 0 + while (cnt < caseNum) { + dut.io.flush.poke(true.B) + dut.clock.step() + dut.io.flush.poke(false.B) + val (matrixA, matrixB, _) = testCases(cnt) + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println(s"case $cnt : matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { 
+ dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + cnt += 1 + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exResult) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRow.valid.peekBoolean()) { + dut.io.curRow.ready.poke(true.B) + val curRowIndex = dut.io.curRow.bits.index.peekInt() + dut.io.curRow.bits.index.expect(rowIdx.U) + println(s"curRow index: $curRowIndex") + for (i <- 0 until n) { + val outBigInt = dut.io.curRow.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exResult(curRowIndex.toInt)(i) + checkResult(out, expected, rowIdx, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRow.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() + } + + "GEMMSingleQueue " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testGEMMSingleQueue[Int](dut) + } + } + + // "GEMMSingleQueue " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMSingleQueue[Float](dut) + // } + // } + + // 
"GEMMFMATotal " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMFMATotal(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMATotal[Int](dut) + // } + // } + + // "GEMMFMATotal " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new GEMMFMATotal(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMATotal[Float](dut) + // } + // } + + // "GEMMFMASingle " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new GEMMFMASingle(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMASingle[Float](dut) + // } + // } + + // "GEMMFMASingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMFMASingle(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMASingle[Int](dut) + // } + // } + + // "MultiFMA " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new MultiFMA(k = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMA[Int](dut) + // } + // } + + // "MultiFMA " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new MultiFMA( k = 4, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + 
// testMultiFMA[Float](dut) + // } + // } + // "MultiFMA_v2 " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new MultiFMA_v2(k = 4, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMA_v2[Float](dut) + // } + // } + // "MultiFMA_v2 " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new MultiFMA_v2(k = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMA_v2[Int](dut) + // } + // } +} diff --git a/src/test/scala/kernel/alu/OutValueTest.scala b/src/test/scala/kernel/alu/OutValueTest.scala new file mode 100644 index 0000000..f6e8481 --- /dev/null +++ b/src/test/scala/kernel/alu/OutValueTest.scala @@ -0,0 +1,191 @@ +package kernel.alu + +import chisel3._ +import chiseltest._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.ParallelTestExecution +import scala.reflect.ClassTag +import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} +import ujson.Arr +import Utils._ + +class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution { + + private def testOutValue[T: Numeric: ClassTag]( + dut: OutValue + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + var caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inScores = matInit[T](m, m) + val inValue = matInit[T](m, n) + val exAttnOut = mmul(inScores, inValue) + (inScores, inValue, exAttnOut) + } + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inScores, inValue, _) = testCases(cnt) + if (dut.io.Scores.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { + println("test case " + cnt + ": Scores and Value are ready") + 
dut.io.Scores.valid.poke(true.B) + dut.io.Value.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until m + } { + dut.io.Scores.bits(row)(i).poke(toBinaryBigInt(inScores(row)(i)).U) + dut.io.Value.bits(i)(col).poke(toBinaryBigInt(inValue(i)(col)).U) + } + cnt += 1 + } else { + dut.io.Scores.valid.poke(false.B) + dut.io.Value.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exAttnOut) = testCases(resCnt) + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.AttnOut.valid.peekBoolean()) { + dut.io.AttnOut.ready.poke(true.B) + val (_, _, exAttnOut) = testCases(resCnt) + for { + row <- 0 until m + col <- 0 until n + } { + val outBigInt = dut.io.AttnOut.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exAttnOut(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.AttnOut.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + } + + private def testOutValueSingle[T: Numeric: ClassTag]( + dut: OutValueSingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + var caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inScores = matInit[T](m, m) + val inValue = matInit[T](m, n) + val exAttnOut = mmul(inScores, inValue) + (inScores, inValue, exAttnOut) + } + + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inScores, inValue, _) = testCases(cnt) + if (dut.io.Value.ready.peekBoolean()) { + println("test case " + cnt + ": Value is ready") + dut.io.Value.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + 
dut.io.Value.bits(row)(col).poke(toBinaryBigInt(inValue(row)(col)).U) + } + cnt += 1 + } else { + dut.io.Value.valid.poke(false.B) + } + var rowIdx = 0 + while (rowIdx < m) { + if (dut.io.curScores.ready.peekBoolean()) { + println(s"case ${cnt-1} curScores index: $rowIdx is ready") + dut.io.curScores.valid.poke(true.B) + for (i <- 0 until m) { + dut.io.curScores.bits.value(i).poke(toBinaryBigInt(inScores(rowIdx)(i)).U) + } + dut.io.curScores.bits.index.poke(rowIdx.U) + rowIdx += 1 + } else { + dut.io.curScores.valid.poke(false.B) + } + dut.clock.step() + } + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exAttnOut) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curAttnOut.valid.peekBoolean()) { + dut.io.curAttnOut.ready.poke(true.B) + val curRowIndex = dut.io.curAttnOut.bits.index.peekInt() + dut.io.curAttnOut.bits.index.expect(rowIdx.U) + for (i <- 0 until n) { + val outBigInt = dut.io.curAttnOut.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exAttnOut(rowIdx)(i) + val precision = 0.001f + checkResult(out, expected, rowIdx, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curAttnOut.ready.poke(false.B) + } + dut.clock.step() + } + dut.clock.step() + dut.io.done.expect(true.B) + println(s"case $resCnt done\n") + resCnt += 1 + } + }.join() + } + + //TODO: Fix the warning in the OutValueSingle module + // warning: This module has an additional loop for the curScores input. + // The error might be caused by the valid signal of the input"curScores", but it hasn't been resolved yet. 
+  "OutValueSingle " should "compute fxp matrix multiplication" in {
+    implicit val config: DataWidthConfig = FxpConfig
+    test(new OutValueSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp))
+      .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut =>
+        testOutValueSingle[Int](dut)
+      }
+  }
+
+  // "OutValue " should "compute fxp matrix multiplication" in {
+  //   implicit val config: DataWidthConfig = FxpConfig
+  //   test(new OutValue(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp))
+  //     .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut =>
+  //       testOutValue[Int](dut)
+  //     }
+  // }
+}
diff --git a/src/test/scala/kernel/alu/utils.scala b/src/test/scala/kernel/alu/utils.scala
new file mode 100644
index 0000000..c61af5f
--- /dev/null
+++ b/src/test/scala/kernel/alu/utils.scala
@@ -0,0 +1,252 @@
+package kernel.alu
+import chisel3._
+import chiseltest._
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.ParallelTestExecution
+import scala.reflect.ClassTag
+import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType}
+import ujson.Arr
+
+/** Helpers shared by the kernel.alu test benches: a reference matrix multiply, test-matrix
+  * generation, conversions between Scala numbers and raw bit patterns, and result checking.
+  */
+object Utils {
+
+  /** Reference matrix product `a * b`, accumulated with the `Numeric[T]` operations.
+    *
+    * @param a left operand, `rows x inner`
+    * @param b right operand, `inner x cols`
+    * @return the `rows x cols` product; an empty `b` yields rows of length zero
+    */
+  def mmul[T: Numeric: ClassTag](a: Array[Array[T]], b: Array[Array[T]]): Array[Array[T]] = {
+    val num = implicitly[Numeric[T]]
+    val inner = b.length
+    // headOption instead of b(0): an empty right operand must not throw.
+    val cols = b.headOption.fold(0)(_.length)
+
+    Array.tabulate(a.length, cols) { (i, j) =>
+      (0 until inner).foldLeft(num.zero) { (acc, idx) =>
+        num.plus(acc, num.times(a(i)(idx), b(idx)(j)))
+      }
+    }
+  }
+
+  /** Builds a `rows x cols` test matrix of pseudo-random values.
+    *
+    * `Int` (fixed-point) entries are integers in [-2, 1]; `Float` and `Double` entries are
+    * `next * 2 - 1` of the generator's `nextFloat()` / `nextDouble()`.
+    *
+    * NOTE(review): the generator is re-created with seed 42 on every call, so two calls with
+    * the same shape and element type return identical matrices. Confirm that this is intended
+    * before relying on several generated cases being different.
+    *
+    * @throws IllegalArgumentException if `T` is not `Int`, `Float` or `Double`
+    */
+  def matInit[T: Numeric: ClassTag](
+    rows: Int,
+    cols: Int
+  )(
+    implicit config: DataWidthConfig
+  ): Array[Array[T]] = {
+    val r = new scala.util.Random(42)
+    val ct = implicitly[ClassTag[T]]
+    val numeric = implicitly[Numeric[T]]
+
+    ct.runtimeClass match {
+      case c if c == classOf[Int] =>
+        // Small range for fixed point; the full-range draw below is kept disabled:
+        // r.nextInt(math.pow(2, config.inputWidth).toInt) - math.pow(2, config.inputWidth - 1).toInt
+        Array.fill(rows, cols)(numeric.fromInt(r.nextInt(4) - 2))
+      case c if c == classOf[Float] =>
+        Array.fill(rows, cols)((r.nextFloat() * 2 - 1).asInstanceOf[T])
+      case c if c == classOf[Double] =>
+        Array.fill(rows, cols)((r.nextDouble() * 2 - 1).asInstanceOf[T])
+      case _ =>
+        throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}")
+    }
+  }
+
+  /** Interprets `value` as a `width`-bit two's-complement pattern: when bit `width - 1` is set,
+    * `2^width` is subtracted; otherwise `value` is returned unchanged.
+    */
+  def toSignedBigInt(value: BigInt, width: Int): BigInt = {
+    val signBit = (value >> (width - 1)) & 1
+    if (signBit == 1) value - (BigInt(1) << width) else value
+  }
+
+  /** Prints a matrix to stdout: entries are tab-separated and every row ends with `;`.
+    * `Float` and `Double` entries use four decimals, `Int` entries are printed as integers.
+    *
+    * @throws IllegalArgumentException if `T` is not `Float`, `Double` or `Int`
+    */
+  def printmat[T: Numeric: ClassTag](m: Array[Array[T]]): Unit = {
+    val ct = implicitly[ClassTag[T]]
+
+    m.foreach { r =>
+      r.foreach { v =>
+        ct.runtimeClass match {
+          case c if c == classOf[Float] =>
+            print(f"${v.asInstanceOf[Float]}%.4f\t")
+          case c if c == classOf[Double] =>
+            print(f"${v.asInstanceOf[Double]}%.4f\t")
+          case c if c == classOf[Int] =>
+            print(f"${v.asInstanceOf[Int]}%d\t")
+          case _ =>
+            throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}")
+        }
+      }
+      println(";")
+    }
+    println()
+  }
+
+  /** Prints a flat array as an `x` by `y` matrix; entry `(i, j)` is read from `m(i * y + j)`.
+    * `BigInt` entries are shown as signed `config.inputWidth`-bit values.
+    *
+    * @throws IllegalArgumentException if `T` is not `Float`, `Double`, `Int` or `BigInt`
+    */
+  def printmat[T: Numeric: ClassTag](m: Array[T], x: Int, y: Int)(implicit config: DataWidthConfig): Unit = {
+    val ct = implicitly[ClassTag[T]]
+
+    for (i <- 0 until x) {
+      for (j <- 0 until y) {
+        ct.runtimeClass match {
+          case c if c == classOf[Float] =>
+            print(f"${m(i * y + j).asInstanceOf[Float]}%.4f\t")
+          case c if c == classOf[Double] =>
+            print(f"${m(i * y + j).asInstanceOf[Double]}%.4f\t")
+          case c if c == classOf[Int] =>
+            print(f"${m(i * y + j).asInstanceOf[Int]}%d\t")
+          case c if c == classOf[BigInt] =>
+            print(f"${toSignedBigInt(m(i * y + j).asInstanceOf[BigInt], config.inputWidth)}%d\t")
+          case _ =>
+            throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}")
+        }
+      }
+      println(";")
+    }
+    println()
+  }
+
+  /** Encodes `v` as a non-negative `BigInt` holding its raw bit pattern.
+    *
+    * `Int` values are reduced to their low `config.inputWidth` two's-complement bits; `Float` and
+    * `Double` values use their raw bits (`floatToRawIntBits` / `doubleToRawLongBits`).
+    *
+    * @throws IllegalArgumentException if `T` is not `Int`, `Float` or `Double`
+    */
+  def toBinaryBigInt[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): BigInt = {
+    val ct = implicitly[ClassTag[T]]
+
+    ct.runtimeClass match {
+      case c if c == classOf[Int] =>
+        // The mask is built in BigInt: `1L << width` wraps at width 64 and would yield an empty mask.
+        val mask = (BigInt(1) << config.inputWidth) - 1
+        BigInt(v.asInstanceOf[Int]) & mask
+      case c if c == classOf[Float] =>
+        BigInt(java.lang.Float.floatToRawIntBits(v.asInstanceOf[Float]).toBinaryString, 2)
+      case c if c == classOf[Double] =>
+        BigInt(java.lang.Double.doubleToRawLongBits(v.asInstanceOf[Double]).toBinaryString, 2)
+      case _ =>
+        throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}")
+    }
+  }
+
+  /** Renders `v` as a binary string, left-padded with `0` to `config.inputWidth` characters.
+    *
+    * `Int` values are cut to their low `config.inputWidth` two's-complement bits, matching
+    * [[toBinaryBigInt]]; `Float` and `Double` strings are padded but never shortened.
+    *
+    * @throws IllegalArgumentException if `T` is not `Int`, `Float` or `Double`
+    */
+  def toBinaryString[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): String = {
+    val ct = implicitly[ClassTag[T]]
+    val width = config.inputWidth
+    // Left-pads with zeros; strings already `width` long or longer are returned unchanged.
+    def leftPad(bits: String): String = bits.reverse.padTo(width, '0').reverse
+
+    ct.runtimeClass match {
+      case c if c == classOf[Int] =>
+        // Masking in BigInt sign-extends negative values for widths above 32 instead of
+        // zero-padding the 32-character `Int.toBinaryString` form.
+        val mask = (BigInt(1) << width) - 1
+        leftPad((BigInt(v.asInstanceOf[Int]) & mask).toString(2))
+      case c if c == classOf[Float] =>
+        leftPad(java.lang.Float.floatToRawIntBits(v.asInstanceOf[Float]).toBinaryString)
+      case c if c == classOf[Double] =>
+        leftPad(java.lang.Double.doubleToRawLongBits(v.asInstanceOf[Double]).toBinaryString)
+      case _ =>
+        throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}")
+    }
+  }
+
+  /** Decodes a raw bit pattern back into `T`; the counterpart of [[toBinaryBigInt]].
+    *
+    * `Int`: the low 32 bits of `bigInt` are taken and, when bit `config.inputWidth - 1` is set,
+    * all bits from `config.inputWidth` upwards are set as well (sign extension). `Float` and
+    * `Double` are rebuilt with `intBitsToFloat` / `longBitsToDouble`.
+    *
+    * NOTE(review): the sign bit position comes from `config.inputWidth`; if the decoded ports are
+    * wider than that, the position should probably follow the port width - TODO confirm.
+    *
+    * @throws IllegalArgumentException if `T` is not `Int`, `Float` or `Double`
+    */
+  def fromBinaryBigInt[T: Numeric: ClassTag](bigInt: BigInt)(implicit config: DataWidthConfig): T = {
+    val ct = implicitly[ClassTag[T]]
+
+    ct.runtimeClass match {
+      case c if c == classOf[Int] =>
+        val raw = bigInt.toInt
+        val width = config.inputWidth
+        // Int shifts use only the low five bits of the distance, so for widths of 32 or more
+        // `1 << width` would wrap; `raw` already is the complete two's-complement value then.
+        val decoded =
+          if (width >= Integer.SIZE) raw
+          else if ((raw & (1 << (width - 1))) != 0) raw | (-1 << width)
+          else raw
+        decoded.asInstanceOf[T]
+      case c if c == classOf[Float] =>
+        java.lang.Float.intBitsToFloat(bigInt.toInt).asInstanceOf[T]
+      case c if c == classOf[Double] =>
+        java.lang.Double.longBitsToDouble(bigInt.toLong).asInstanceOf[T]
+      case _ =>
+        throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}")
+    }
+  }
+
+  /** Compares one result entry with its expected value.
+    *
+    * The entry counts as wrong when `|out - expected| > precision`. On a mismatch the position,
+    * the actual value and the expected value are printed.
+    *
+    * @param row row index, used in the message only
+    * @param col column index, used in the message only
+    * @return `Some(())` on a mismatch, `None` when the values agree within `precision`
+    * @throws IllegalArgumentException if `T` is not `Float`, `Double` or `Int`
+    */
+  def checkResult[T: Numeric: ClassTag](
+    out: T,
+    expected: T,
+    row: Int,
+    col: Int,
+    precision: Float
+  ): Option[Unit] = {
+    val isInvalid = implicitly[ClassTag[T]].runtimeClass match {
+      case c if c == classOf[Float] =>
+        math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision
+      case c if c == classOf[Double] =>
+        math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision
+      case c if c == classOf[Int] =>
+        math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision
+      case _ =>
+        throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}")
+    }
+
+    if (isInvalid) {
+      println(s"Error: row: $row col: $col")
+      printmat(Array(Array(out)))
+      printmat(Array(Array(expected)))
+      Some(())
+    } else None
+  }
+}