From 08c6a3fe6820e0a30eb812e97f4207adc30543dc Mon Sep 17 00:00:00 2001 From: pyfirstcsh <8295488+cao-shuai-hu@user.noreply.gitee.com> Date: Tue, 24 Dec 2024 22:18:21 +0800 Subject: [PATCH 01/10] add:GemmFMA&&Test --- src/main/scala/kernel/alu/GemmFMA.scala | 553 ++++++++++++++++++ src/test/scala/kernel/alu/GemmFMATest.scala | 609 ++++++++++++++++++++ 2 files changed, 1162 insertions(+) create mode 100644 src/main/scala/kernel/alu/GemmFMA.scala create mode 100644 src/test/scala/kernel/alu/GemmFMATest.scala diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala new file mode 100644 index 0000000..f6cf4fc --- /dev/null +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -0,0 +1,553 @@ +package kernel.alu + +import chisel3._ +import chisel3.util._ +import kernel.alu.GEMMDataType +import kernel.alu.DataWidthConfig +import kernel.utils.DebugLog + +class currentRowIndex( + val m: Int, + val n: Int +)( + implicit config: DataWidthConfig) + extends Bundle { + val index = Output(UInt(log2Ceil(m).W)) //输出的行索引 + val value = Output(Vec(n, UInt(config.outputWidth.W))) //输出的行值 +} + +class MultiFMAMM( + val k: Int = 64, + val PECount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val matrixA_row = Input(Vec(k, UInt(config.inputWidth.W))) // 矩阵1的一行 + val matrixB_cols = Input(Vec(PECount, Vec(k, UInt(config.inputWidth.W)))) // 矩阵2的PECount列向量 + val results = Output(Vec(PECount, UInt(config.outputWidth.W))) // 点积结果 + val valids = Output(Vec(PECount, Bool())) // 结果有效标志 + val reset = Input(Bool()) + }) + + // 创建 PECount 个 PE 实例 + val pes = Seq.fill(PECount)(gemmType match { + case GEMMDataType.Fxp => Module(new PEFxp()).io + case GEMMDataType.Fp32 => Module(new PEFp()).io + case GEMMDataType.Fp64 => Module(new PEFp()).io + case _ => throw new IllegalArgumentException("Unsupported GEMM type") + }) + + // 当前索引寄存器 + val index = RegInit(0.U(log2Ceil(k).W)) + // 结果有效标志 + val valid = RegInit(false.B) + + // 连接每个 PE 的输入和输出 + for (i <- 0 until PECount) { + val pe = pes(i) + pe.reset := io.reset + pe.in_h := io.matrixA_row(index) // 矩阵1的当前行值 + pe.in_v := io.matrixB_cols(i)(index) // 矩阵2的第i列当前值 + io.results(i) := pe.out // 输出结果 + io.valids(i) := valid // 当点积完成时置位 valid + } + + // 索引、结果有效位控制逻辑 + when(io.reset) { + index := 0.U + valid := false.B + }.elsewhen(index =/= (k - 1).U) { + index := index + 1.U + valid := false.B + }.elsewhen(index === (k - 1).U) { + index := 0.U + valid := true.B + } +} + +class GEMMFMA( + val m: Int, + val k: Int, + val n: Int, + val PECount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + // require(m % PECount == 0 && k % PECount == 0 && n % PECount == 0, "Matrix dimensions must be divisible by PECount") + val io = IO(new Bundle { + val matrixA = Input(Vec(m, Vec(k, UInt(config.inputWidth.W)))) // 矩阵A + val matrixB = Input(Vec(k, Vec(n, UInt(config.inputWidth.W)))) // 矩阵B + val results = Output(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 结果矩阵 + val done = Output(Bool()) // 完成标志 + }) + + val rowIndex = RegInit(0.U(log2Ceil(m).W)) // 当前行索引 + val colIndex = RegInit(0.U(log2Ceil(n).W)) // 当前列块索引 + val resultMatrix = Reg(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 存储结果 + val doneFlag = RegInit(false.B) // 完成标志 + val resetFlag = RegInit(false.B) // 复位标志 + + // 实例化MultiFMA + val multiFMA = Module(new MultiFMAMM(k, PECount, gemmType)) + + // 输入连接 + multiFMA.io.matrixA_row := io.matrixA(rowIndex) + multiFMA.io.matrixB_cols := VecInit(Seq.tabulate(PECount) { i => + VecInit(io.matrixB.map(_(colIndex + i.U))) + }) + + // 一块处理结束, 结果存储, 更新块索引 + when(multiFMA.io.valids.reduce(_ && _)) { + for (i <- 0 until PECount) { + resultMatrix(rowIndex)(colIndex + i.U) := multiFMA.io.results(i) + } + resetFlag := true.B + // 更新块索引 + when(colIndex === (n - PECount).U) { + colIndex := 0.U + when(rowIndex === (m - 1).U) { + rowIndex := 0.U + doneFlag := true.B + }.otherwise { + rowIndex := rowIndex + 1.U + } + }.otherwise { + colIndex := colIndex + PECount.U + } + } + + // 复位控制,也即清空累加器 + when(resetFlag) { + multiFMA.io.reset := true.B + resetFlag := false.B + }.otherwise { + multiFMA.io.reset := false.B + } + + io.done := doneFlag + io.results := resultMatrix +} + +class GEMMFMATotal( + val m: Int, + val k: Int, + val n: Int, + val PECount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + require(m % PECount == 0 && k % PECount == 0 && n % PECount == 0, "Matrix dimensions must be divisible by PECount") + val io = IO(new Bundle { + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B + val results = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 结果矩阵 + }) + + val rowIndex = Counter(m) + val colIndex = Counter(n / PECount) + + val resultsReg = Reg(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 存储结果 + val resetFlag = RegInit(false.B) // 复位标志,用于清空MultiFMA + + val dataValid = io.matrixA.valid && io.matrixB.valid + val readyReg = RegInit(true.B) + io.matrixA.ready := readyReg + io.matrixB.ready := readyReg + + io.results.valid := false.B + io.results.bits := DontCare + + // 实例化MultiFMA + val multiFMA = Module(new MultiFMAMM(k, PECount, gemmType)) + + // 输入连接 + multiFMA.io.matrixA_row := io.matrixA.bits(rowIndex.value) + multiFMA.io.matrixB_cols := VecInit(Seq.tabulate(PECount) { i => + VecInit(io.matrixB.bits.map(_((colIndex.value * PECount.U + i.U) % n.U))) + }) + + // 状态机定义 + object state extends ChiselEnum { + val idle, compute, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.compute + } + } + + is(state.compute) { + when(multiFMA.io.valids.reduce(_ && _)) { + for (i <- 0 until PECount) { + resultsReg(rowIndex.value)((colIndex.value * PECount.U + i.U) % n.U) := multiFMA.io.results(i) + } + resetFlag := true.B + stateReg := state.update + } + } + + is(state.update) { + when(colIndex.inc()) { + when(rowIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.compute + } + }.otherwise { + stateReg := state.compute + } + + } + is(state.done) { + readyReg := true.B + io.results.valid := true.B + io.results.bits := resultsReg + stateReg := state.idle + } + } + when(resetFlag) { + multiFMA.io.reset := true.B + resetFlag := false.B + }.otherwise { + multiFMA.io.reset := false.B + } + +} +class GEMMFMASingle( + val m: Int, + val k: Int, + val n: Int, + val PECount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + require(m % PECount == 0 && k % PECount == 0 && n % PECount == 0, "Matrix dimensions must be divisible by PECount") + val io = IO(new Bundle { + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B + val currentRow = Decoupled(new currentRowIndex(m, n)) //输出的行索引 + val done = Output(Bool()) // 整个矩阵完成标志 + }) + + val rowIndex = Counter(m) + val colIndex = Counter(n / PECount) + + val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) // 存储当前行结果 + val doneFlag = RegInit(false.B) // 完成标志 + val resetFlag = RegInit(false.B) // 复位标志 + + val dataValid = io.matrixA.valid && io.matrixB.valid + val readyReg = RegInit(true.B) + io.matrixA.ready := readyReg + io.matrixB.ready := readyReg + + io.currentRow.valid := false.B + io.currentRow.bits := DontCare + + // 实例化MultiFMA + val multiFMA = Module(new MultiFMAMM(k, PECount, gemmType)) + + // 输入连接 + multiFMA.io.matrixA_row := io.matrixA.bits(rowIndex.value) + multiFMA.io.matrixB_cols := VecInit(Seq.tabulate(PECount) { i => + VecInit(io.matrixB.bits.map(_((colIndex.value * PECount.U + i.U) % n.U))) + }) + + // 状态机定义 + object state extends ChiselEnum { + val idle, compute, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.compute + } + } + + is(state.compute) { + when(multiFMA.io.valids.reduce(_ && _)) { + for (i <- 0 until PECount) { + currentRowReg((colIndex.value * PECount.U + i.U) % n.U) := multiFMA.io.results(i) + } + resetFlag := true.B + stateReg := state.update + } + } + + is(state.update) { + io.currentRow.valid := false.B + when(colIndex.inc()) { + io.currentRow.valid := true.B + io.currentRow.bits.index := rowIndex.value + io.currentRow.bits.value := currentRowReg + when(rowIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.compute + } + }.otherwise { + stateReg := state.compute + } + + } + is(state.done) { + doneFlag := true.B + readyReg := true.B + stateReg := state.idle + } + } + when(resetFlag) { + multiFMA.io.reset := true.B + resetFlag := false.B + }.otherwise { + multiFMA.io.reset := false.B + } + io.done := doneFlag +} + +// TODO: 优化,bug +class QKMulFMA( + val m: Int, + val k: Int, + val n: Int, + val PECount1: Int = 16, + val PECount2: Int = 16, + val gemmType: GEMMDataType.Type, + val bufferSizeGemm: Int = 32 +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + + class QKGenderWarper( + val m: Int, + val k: Int, + val n: Int, + val PECount: Int = 16, + val gemmType: GEMMDataType.Type, + val bufferSize: Int + )( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B + val flush = Input(Bool()) + val outMatrix = Decoupled(new currentRowIndex(m, n)) + }) + + val qkGenMul = Module(new GEMMFMASingle(m, k, n, PECount, gemmType)) + io.matrixA <> qkGenMul.io.matrixA + io.matrixB <> qkGenMul.io.matrixB + + val currentBuffer = Module( + new Queue( + new currentRowIndex(m, n), + entries = bufferSize, + pipe = true, + flow = false, + useSyncReadMem = false, + hasFlush = true + ) + ) + + // hasFlush must be true + currentBuffer.io.flush.get := io.flush + + // ATTENTION: we assert the size of the buffer is huge enough to hold the current systolic group output + // we ignore the ready signal of the enq + currentBuffer.io.enq.bits := qkGenMul.io.currentRow.bits + currentBuffer.io.enq.valid := qkGenMul.io.currentRow.valid + + io.outMatrix <> currentBuffer.io.deq + } + + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val score = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + val resetBuffer = Input(Bool()) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + + // QKGen,Q: m * n, K: m * n + val qGen = Module(new QKGenderWarper(m, k, n, PECount1, gemmType, bufferSizeGemm)) + val kGen = Module(new QKGenderWarper(m, k, n, PECount1, gemmType, bufferSizeGemm)) + + qGen.io.matrixA <> io.inputToken + qGen.io.matrixB <> io.weightQ + kGen.io.matrixA <> io.inputToken + kGen.io.matrixB <> io.weightK + + qGen.io.flush := io.resetBuffer + kGen.io.flush := io.resetBuffer + + // // QKMul Q*K^T, Q: m * n, K: m * n -> m * m + // val Qrow = qGen.io.outMatrix.bits.value // one row of Q: 1 * n + // val Krow = kGen.io.outMatrix.bits.value // one row of K: 1 * n + // val QIndex = qGen.io.outMatrix.bits.index // the index of Q row + // val KIndex = kGen.io.outMatrix.bits.index // the index of K row + + // 创建一个 MultiFMAMM 模块来计算 Q 的一行和 K 的多列的乘积结果 中间维度为n + // val multiFMA = Module(new MultiFMAMM(n, PECount2, gemmType)) + + val qQueue = Module(new Queue(new currentRowIndex(m, n), bufferSizeGemm)) + val kQueue = Module(new Queue(new currentRowIndex(m, n), bufferSizeGemm)) + + // 将生成的每一行数据存储到队列中 + qQueue.io.enq.bits := qGen.io.outMatrix.bits + qQueue.io.enq.valid := qGen.io.outMatrix.valid + kQueue.io.enq.bits := kGen.io.outMatrix.bits + kQueue.io.enq.valid := kGen.io.outMatrix.valid +// 创建一个 M*N 的寄存器组来保存所有的 Q 和 K 值 + val qMatrix = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + val k_TMatrix = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + + // 状态机定义 + object state extends ChiselEnum { + val idle, load, compute, done = Value + } + val stateReg = RegInit(state.idle) + + // 计数器,用于跟踪 Q 和 K 的行数 + val qCounter = RegInit(0.U(log2Ceil(m).W)) + val kCounter = RegInit(0.U(log2Ceil(m).W)) + + // 创建一个 M*M 的寄存器组来保存所有的结果 + val scoreValue = RegInit(VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W)))))) + + switch(stateReg) { + is(state.idle) { + when(qQueue.io.enq.valid && kQueue.io.enq.valid) { + stateReg := state.load + } + } + + is(state.load) { + when(qQueue.io.deq.valid && kQueue.io.deq.valid) { + qMatrix(qQueue.io.deq.bits.index) := qQueue.io.deq.bits.value + for (i <- 0 until n) { + k_TMatrix(i)(kQueue.io.deq.bits.index) := kQueue.io.deq.bits.value(i) // 将 K 的值存储到转置后的 kMatrix 中 + } + qCounter := qCounter + 1.U + kCounter := kCounter + 1.U + qQueue.io.deq.ready := true.B + kQueue.io.deq.ready := true.B + when(qCounter === (m - 1).U && kCounter === (m - 1).U) { + stateReg := state.compute + } + } + } + + is(state.compute) { + val multiFMA = Module(new GEMMFMA(m, n, m, PECount2, gemmType)) + multiFMA.io.matrixA := qMatrix + multiFMA.io.matrixB := k_TMatrix + io.score.bits := multiFMA.io.results + io.score.valid := multiFMA.io.done + stateReg := state.done + } + + is(state.done) { + // 完成标志 + io.score.valid := true.B + stateReg := state.idle + } + } + + when(io.resetBuffer) { + qCounter := 0.U + kCounter := 0.U + scoreValue := VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W))))) + stateReg := state.idle + } + // // 创建一个 M*N 的寄存器组来保存所有的 K 值 + // val kMatrix = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + + // // 从队列中提取 Q 的一行和 K 的多列 + // val qRowFromQueue = qQueue.io.deq.bits.value + // val qIndexFromQueue = qQueue.io.deq.bits.index + + // val kColsFromQueue = Reg(Vec(PECount2, Vec(n, UInt(config.inputWidth.W)))) + // val kIndexFromQueue = Reg(Vec(PECount2, UInt(log2Ceil(m).W))) + + // // 计数器,用于跟踪 K 的列数 + // val kCounter = RegInit(0.U(log2Ceil(n / PECount2).W)) + + // // 当 K 队列中有足够的列时,提取 K 的多列 + // when(kQueue.io.deq.valid && kCounter < PECount2.U) { + // kColsFromQueue(kCounter) := kQueue.io.deq.bits.value + // kIndexFromQueue(kCounter) := kQueue.io.deq.bits.index + // kMatrix(kQueue.io.deq.bits.index) := kQueue.io.deq.bits.value + // kCounter := kCounter + 1.U + // kQueue.io.deq.ready := true.B + // }.otherwise { + // kQueue.io.deq.ready := false.B + // } + + // // 当 K 队列中有足够的列时,进行矩阵乘法 + // when(kCounter === PECount2.U) { + // multiFMA.io.matrixA_row := qRowFromQueue + // multiFMA.io.matrixB_cols := kColsFromQueue + // kCounter := 0.U + // } + + // // 连接结果和有效标志 + // val scoreValue = RegInit(VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W)))))) + // for (i <- 0 until PECount2) { + // when(multiFMA.io.valids(i)) { + // scoreValue(qIndexFromQueue)(kIndexFromQueue(i)) := multiFMA.io.results(i) + // } + // } + + // io.score.bits := scoreValue + // io.score.valid := qQueue.io.deq.valid && kQueue.io.deq.valid + + // // 当 qQueue 继续有值时,继续处理 + // when(qQueue.io.deq.valid && kQueue.io.deq.valid) { + // qQueue.io.deq.ready := true.B + // }.otherwise { + // qQueue.io.deq.ready := false.B + // } + + // when(io.resetBuffer) { + // kCounter := 0.U + // scoreValue := VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W))))) + // } + + // // final result idx + // val rowIdx = RegInit(0.U(log2Ceil(m / PECount2).W)) + // val colIdx = RegInit(0.U(log2Ceil(m / PECount2).W)) + // val resValid = RegInit(false.B) + // io.score.valid := resValid + + // io.score.bits := scoreValue + + // when(resValid && io.score.ready) { + // resValid := false.B + // } + +} diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala new file mode 100644 index 0000000..491a017 --- /dev/null +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -0,0 +1,609 @@ +package kernel.alu + +import chisel3._ +import chiseltest._ +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.ParallelTestExecution +import scala.reflect.ClassTag +import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} + +class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution { + + def mmul[T: Numeric: ClassTag](a: Array[Array[T]], b: Array[Array[T]]): Array[Array[T]] = { + val rows = a.length + val cols = b(0).length + val n = b.length + val num = implicitly[Numeric[T]] + + Array.tabulate(rows, cols) { (i, j) => + var sum = num.zero + for (k <- 0 until n) { + sum = num.plus(sum, num.times(a(i)(k), b(k)(j))) + } + sum + } + } + + def matInit[T: Numeric: ClassTag]( + rows: Int, + cols: Int + )( + implicit config: DataWidthConfig + ): Array[Array[T]] = { + val r = new scala.util.Random(42) + val ct = implicitly[ClassTag[T]] + val numeric = implicitly[Numeric[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + // 定点数使用 -8 到 7 的整数 + Array.fill(rows, cols)( + numeric.fromInt( + // r.nextInt(math.pow(2, config.inputWidth).toInt) - math.pow(2, config.inputWidth - 1).toInt + r.nextInt(4) - 2 + ) + ) + case c if c == classOf[Float] => + // 32位浮点数使用 -1 到 1 的随机浮点数 + // Float 类型 + Array.fill(rows, cols)((r.nextFloat() * 2 - 1).asInstanceOf[T]) + case c if c == classOf[Double] => + // 64位浮点数使用 -1 到 1 的随机浮点数 + Array.fill(rows, cols)((r.nextDouble() * 2 - 1).asInstanceOf[T]) + case _ => + throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") + } + } + + def toSignedBigInt(value: BigInt, width: Int): BigInt = { + val signBit = (value >> (width - 1)) & 1 + + if (signBit == 1) { + val maxValue = BigInt(1) << width + value - maxValue + } else { + value + } + } + + def printmat[T: Numeric: ClassTag](m: Array[Array[T]]): Unit = { + val numeric = implicitly[Numeric[T]] + val ct = implicitly[ClassTag[T]] + + m.foreach { r => + r.foreach { v => + ct.runtimeClass match { + case c if c == classOf[Float] => + print(f"${v.asInstanceOf[Float]}%.4f\t") + case c if c == classOf[Double] => + print(f"${v.asInstanceOf[Double]}%.4f\t") + case c if c == classOf[Int] => + print(f"${v.asInstanceOf[Int]}%d\t") + case _ => + throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") + } + } + println(";") + } + println() + } + + def printmat[T: Numeric: ClassTag](m: Array[T], x: Int, y: Int)(implicit config: DataWidthConfig): Unit = { + val numeric = implicitly[Numeric[T]] + val ct = implicitly[ClassTag[T]] + + for (i <- 0 until x) { + for (j <- 0 until y) { + ct.runtimeClass match { + case c if c == classOf[Float] => + print(f"${m(i * y + j).asInstanceOf[Float]}%.4f\t") + case c if c == classOf[Double] => + print(f"${m(i * y + j).asInstanceOf[Double]}%.4f\t") + case c if c == classOf[Int] => + print(f"${m(i * y + j).asInstanceOf[Int]}%d\t") + case c if c == classOf[BigInt] => + print(f"${toSignedBigInt(m(i * y + j).asInstanceOf[BigInt], config.inputWidth)}%d\t") + case _ => + throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") + } + } + println(";") + } + println() + } + + // convert T to binary bigInt + def toBinaryBigInt[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): BigInt = { + val ct = implicitly[ClassTag[T]] + val num = implicitly[Numeric[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + val intValue = v.asInstanceOf[Int] + // 使用 inputWidth 位来表示所有整数,保持符号位 + val mask = (1L << config.inputWidth) - 1 + BigInt(intValue) & mask + case c if c == classOf[Float] => + BigInt(java.lang.Float.floatToRawIntBits(v.asInstanceOf[Float]).toBinaryString, 2) + case c if c == classOf[Double] => + BigInt(java.lang.Double.doubleToRawLongBits(v.asInstanceOf[Double]).toBinaryString, 2) + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") + } + } + + // convrt T to binary string + private def toBinaryString[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): String = { + val ct = implicitly[ClassTag[T]] + val num = implicitly[Numeric[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + val intBValue = v.asInstanceOf[Int].toBinaryString + if (intBValue.length < config.inputWidth) { + intBValue.reverse.padTo(config.inputWidth, '0').reverse + } else { + intBValue.takeRight(config.inputWidth) + } + case c if c == classOf[Float] => + java.lang.Float + .floatToRawIntBits(v.asInstanceOf[Float]) + .toBinaryString + .reverse + .padTo(config.inputWidth, '0') + .reverse + case c if c == classOf[Double] => + java.lang.Double + .doubleToRawLongBits(v.asInstanceOf[Double]) + .toBinaryString + .reverse + .padTo(config.inputWidth, '0') + .reverse + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") + } + } + + // convert binary bigInt to T + def fromBinaryBigInt[T: Numeric: ClassTag](bigInt: BigInt)(implicit config: DataWidthConfig): T = { + val ct = implicitly[ClassTag[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + val intValue = bigInt.toInt + // 处理符号位 + val signExtendedValue = if ((intValue & (1 << (config.inputWidth - 1))) != 0) { + intValue | ~((1 << config.inputWidth) - 1) + } else { + intValue + } + signExtendedValue.asInstanceOf[T] + case c if c == classOf[Float] => + java.lang.Float.intBitsToFloat(bigInt.toInt).asInstanceOf[T] + case c if c == classOf[Double] => + java.lang.Double.longBitsToDouble(bigInt.toLong).asInstanceOf[T] + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") + } + } + private def testMultiFMAMM[T: Numeric: ClassTag]( + dut: MultiFMAMM + )( + implicit config: DataWidthConfig + ): Unit = { + val k = dut.k + val PECount = dut.PECount + val gemmType = dut.gemmType + + val matrixA_row = matInit[T](1, k) + val matrixB_cols = matInit[T](k, PECount) + val expectedResults = mmul(matrixA_row, matrixB_cols) + printmat(matrixA_row) + printmat(matrixB_cols) + printmat(expectedResults) + // 初始化输入 + dut.io.reset.poke(true.B) + dut.clock.step(1) + dut.io.reset.poke(false.B) + + // 逐元素输入数据 + for (i <- matrixA_row(0).indices) { + dut.io.matrixA_row(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) + for (j <- 0 until PECount) { + dut.io.matrixB_cols(j)(i).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) + } + } + + while (!dut.io.valids.forall(_.peekBoolean())) { + dut.clock.step() + } + + val precision = 0.001f + var invalidcnt = 0 + + for (i <- 0 until PECount) { + val outBigInt = dut.io.results(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(0)(i) + + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + }) + + if (isInvalid) { + println("Error: ") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + dut.io.valids(i).expect(true.B) + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testGEMMFMA[T: Numeric: ClassTag]( + dut: GEMMFMA + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val PECount = dut.PECount + val gemmType = dut.gemmType + + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val expectedResults = mmul(matrixA, matrixB) + + // 初始化输入 + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.matrixA(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + } + } + + while (!dut.io.done.peekBoolean()) { + dut.clock.step() + } + val precision = 0.001f + var invalidcnt = 0 + + for (row <- 0 until m) { + for (col <- 0 until n) { + val outBigInt = dut.io.results(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + }) + // printmat(Array(Array(out))) + // printmat(Array(Array(expected))) + if (isInvalid) { + println("Error: ") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testGEMMFMATotal[T: Numeric: ClassTag]( + dut: GEMMFMATotal + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val PECount = dut.PECount + val gemmType = dut.gemmType + + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val expectedResults = mmul(matrixA, matrixB) + // printmat(expectedResults) + + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println("matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + } + } + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) + } + + while (!dut.io.results.valid.peekBoolean()) { + dut.clock.step() + } + + val precision = 0.001f + var invalidcnt = 0 + + for (row <- 0 until m) { + for (col <- 0 until n) { + val outBigInt = dut.io.results.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + }) + // printmat(Array(Array(out))) + // printmat(Array(Array(expected))) + if (isInvalid) { + println("Error: ") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testGEMMFMASingle[T: Numeric: ClassTag]( + dut: GEMMFMASingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val PECount = dut.PECount + val gemmType = dut.gemmType + + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val expectedResults = mmul(matrixA, matrixB) + // printmat(expectedResults) + + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println("matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + } + } + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) + } + + val precision = 0.001f + var invalidcnt = 0 + + while (!dut.io.done.peekBoolean()) { + if (dut.io.currentRow.valid.peekBoolean()) { + val currentRowIndex = dut.io.currentRow.bits.index.peekInt() + println("currentRow index: " + currentRowIndex) + for (i <- 0 until n) { + val outBigInt = dut.io.currentRow.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(currentRowIndex.toInt)(i) + println("i: " + i) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + }) + if (isInvalid) { + println("Error: ") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + dut.clock.step() + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testQKMulFMA[T: Numeric: ClassTag]( + dut: QKMulFMA + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val PECount1 = 4 + val PECount2 = 4 + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val W_q = mmul(inputToken, weightQ) + val W_k = mmul(inputToken, weightK) + val expectedResults = mmul(W_q, W_k.transpose) // W_q * W_k^T + + if ( + dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + } + } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + while (!dut.io.score.valid.peekBoolean()) { + dut.clock.step() + } + + val precision = 0.001f + var invalidcnt = 0 + + for (row <- 0 until m) { + for (col <- 0 until m) { + val outBigInt = dut.io.score.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + }) + if (isInvalid) { + println("Error: ") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + // "QKMulFMA " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMulFMA(m = 4, k = 4, n = 4, PECount1 = 4, PECount2 = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMulFMA[Int](dut) + // } + // } + // "GEMMFMATotal " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMFMATotal(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMATotal[Int](dut) + // } + // } + + // "GEMMFMASingle " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new GEMMFMASingle(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMASingle[Float](dut) + // } + // } + + "GEMMFMASingle " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new GEMMFMASingle(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testGEMMFMASingle[Int](dut) + } + } + + // "MultiFMAMM " should "compute fp32 dot product" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new MultiFMAMM(k = 4, PECount = 16, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMAMM[Float](dut) + // } + // } + + // "MultiFMAMM " should "compute fp64 dot product" in { + // implicit val config: DataWidthConfig = Fp64Config + // test(new MultiFMAMM(k = 4, PECount = 16, gemmType = GEMMDataType.Fp64)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMAMM[Double](dut) + // } + // } + + // "MultiFMAMM " should "compute fxp dot product" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new MultiFMAMM(k = 4, PECount = 16, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMAMM[Int](dut) + // } + // } + + // "GEMMFMA " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new GEMMFMA(m = 4, k = 4, n = 16, PECount = 16, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMA[Float](dut) + // } + // } + + // "GEMMFMA " should "compute fp64 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp64Config + // test(new GEMMFMA(m = 4, k = 4, n = 16, PECount = 16, gemmType = GEMMDataType.Fp64)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMA[Double](dut) + // } + // } + + // "GEMMFMA " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMFMA(m = 4, k = 4, n = 16, PECount = 16, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMA[Int](dut) + // } + // } +} From 918111d3e7335fbf765b57513ef14be405ac84ac Mon Sep 17 00:00:00 2001 From: pyfirstcsh Date: Wed, 25 Dec 2024 17:54:56 +0800 Subject: [PATCH 02/10] add:QKMulFMASingle test bug --- src/main/scala/kernel/alu/GemmFMA.scala | 140 ++++++++++++++++- src/test/scala/kernel/alu/GemmFMATest.scala | 162 +++++++++++++++----- 2 files changed, 265 insertions(+), 37 deletions(-) diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala index f6cf4fc..e35aaee 100644 --- a/src/main/scala/kernel/alu/GemmFMA.scala +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -253,7 +253,6 @@ class GEMMFMASingle( val readyReg = RegInit(true.B) io.matrixA.ready := readyReg io.matrixB.ready := readyReg - io.currentRow.valid := false.B io.currentRow.bits := DontCare @@ -321,7 +320,146 @@ class GEMMFMASingle( io.done := doneFlag } +class GEMMSingleQueue( + val m: Int, + val k: Int, + val n: Int, + val PECount: Int = 16, + val gemmType: GEMMDataType.Type, + val bufferSize: Int = 32 +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B + val flush = Input(Bool()) + val currentRow = Decoupled(new currentRowIndex(m, n)) + val done = Output(Bool()) + }) + + val currentBuffer = Module( + new Queue( + new currentRowIndex(m, n), + entries = bufferSize, + pipe = true, + flow = false, + useSyncReadMem = false, + hasFlush = true + ) + ) + val gemm = Module(new GEMMFMASingle(m, k, n, PECount, gemmType)) + gemm.io.matrixA <> io.matrixA + gemm.io.matrixB <> io.matrixB + currentBuffer.io.flush.get := io.flush + currentBuffer.io.enq <> gemm.io.currentRow + io.currentRow <> currentBuffer.io.deq + io.done := gemm.io.done + +} + +// first use GEMMFMATotal to get Q and K, then use GEMMFMASingle to get Q*K^T +// out one row of score matrix +class QKMulFMASingle( + val m: Int, + val k: Int, + val n: Int, + val PECount1: Int = 16, + val PECount2: Int = 16, + val gemmType: GEMMDataType.Type, + val bufferSizeGemm: Int = 32 +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val scoreRow = Decoupled(new currentRowIndex(m, m)) + val resetBuffer = Input(Bool()) + val done = Output(Bool()) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + io.scoreRow.valid := false.B + io.scoreRow.bits := DontCare + io.done := false.B + + //use GEMMFMATotal to get Q and K + val qGen = Module(new GEMMFMATotal(m, k, n, PECount1, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, PECount1, gemmType)) + qGen.io.matrixA <> io.inputToken + qGen.io.matrixB <> io.weightQ + kGen.io.matrixA <> io.inputToken + kGen.io.matrixB <> io.weightK + + // when qGen and kGen are done, use GEMMFMASingle to get Q*K^T + // Q: m * n, K: m * n -> Q*K^T: m * m + val QK_TMul = Module(new GEMMSingleQueue(m, n, m, PECount2, gemmType, bufferSizeGemm)) + QK_TMul.io.matrixA <> qGen.io.results + + val K_T = VecInit(Seq.fill(n)(VecInit(Seq.fill(m)(0.U(config.inputWidth.W))))) + for (i <- 0 until k) { + for (j <- 0 until n) { + K_T(i)(j) := kGen.io.results.bits(j)(i) + } + } + + QK_TMul.io.matrixB.valid := kGen.io.results.valid + // QK_TMul.io.matrixB.bits := VecInit(kGen.io.results.bits.transpose.map(VecInit(_))) + QK_TMul.io.matrixB.bits := K_T + kGen.io.results.ready := QK_TMul.io.matrixB.ready + + QK_TMul.io.flush := io.resetBuffer + io.scoreRow <> QK_TMul.io.currentRow + + object state extends ChiselEnum { + val idle, gen, mul, collect, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + when(qGen.io.results.valid && kGen.io.results.valid) { + debugLog(p"qGen results: ${qGen.io.results.bits}\n") + debugLog(p"kGen results: ${kGen.io.results.bits}\n") + stateReg := state.mul + } + } + is(state.mul) { + when(QK_TMul.io.currentRow.valid) { + stateReg := state.collect + } + } + is(state.collect) { + when(QK_TMul.io.done) { + stateReg := state.done + } + } + is(state.done) { + io.done := true.B + readyReg := true.B + stateReg := state.idle + } + } + +} + // TODO: 优化,bug +//first use GEMMFMATotal to get Q and K, then use GEMMFMASingle to get Q*K^T class QKMulFMA( val m: Int, val k: Int, diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala index 491a017..f501539 100644 --- a/src/test/scala/kernel/alu/GemmFMATest.scala +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -452,8 +452,81 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe else println(s"Verification failed with $invalidcnt errors.") } - private def testQKMulFMA[T: Numeric: ClassTag]( - dut: QKMulFMA + private def testGEMMSingleQueue[T: Numeric: ClassTag]( + dut: GEMMSingleQueue + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val expectedResults = mmul(matrixA, matrixB) + printmat(expectedResults) + + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println("matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + } + } + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) + } + + dut.io.currentRow.ready.poke(true.B) + + val precision = 0.001f + var invalidcnt = 0 + + while (!dut.io.done.peekBoolean()) { + if (dut.io.currentRow.valid.peekBoolean()) { + val currentRowIndex = dut.io.currentRow.bits.index.peekInt() + println("currentRow index: " + currentRowIndex) + for (i <- 0 until n) { + val outBigInt = dut.io.currentRow.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(currentRowIndex.toInt)(i) + println("i: " + i) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: ") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + dut.clock.step() + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + private def testQKMulFMASingle[T: Numeric: ClassTag]( + dut: QKMulFMASingle )( implicit config: DataWidthConfig ): Unit = { @@ -468,8 +541,11 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val weightQ = matInit[T](k, n) val weightK = matInit[T](k, n) val W_q = mmul(inputToken, weightQ) + printmat(W_q) val W_k = mmul(inputToken, weightK) + printmat(W_k.transpose) val expectedResults = mmul(W_q, W_k.transpose) // W_q * W_k^T + printmat(expectedResults) if ( dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() @@ -493,46 +569,60 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe dut.io.weightK.valid.poke(false.B) } - while (!dut.io.score.valid.peekBoolean()) { - dut.clock.step() - } - + dut.io.scoreRow.ready.poke(true.B) val precision = 0.001f var invalidcnt = 0 - - for (row <- 0 until m) { - for (col <- 0 until m) { - val outBigInt = dut.io.score.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - }) - if (isInvalid) { - println("Error: ") + while (!dut.io.done.peekBoolean()) { + if (dut.io.scoreRow.valid.peekBoolean()) { + val scoreRowIndex = dut.io.scoreRow.bits.index.peekInt() + println("scoreRow index: " + scoreRowIndex) + for (i <- 0 until n) { + val outBigInt = dut.io.scoreRow.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(scoreRowIndex.toInt)(i) + println("i: " + i) printmat(Array(Array(out))) printmat(Array(Array(expected))) - invalidcnt += 1 + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: ") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } } } + dut.clock.step() } + if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } - // "QKMulFMA " should "compute fxp matrix multiplication" in { + "QKMulFMASingle " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new QKMulFMASingle(m = 4, k = 4, n = 4, PECount1 = 4, PECount2 = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testQKMulFMASingle[Int](dut) + } + } + + // "GEMMSingleQueue " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new QKMulFMA(m = 4, k = 4, n = 4, PECount1 = 4, PECount2 = 4, gemmType = GEMMDataType.Fxp)) + // test(new GEMMSingleQueue(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testQKMulFMA[Int](dut) + // testGEMMSingleQueue[Int](dut) // } // } // "GEMMFMATotal " should "compute fxp matrix multiplication" in { @@ -551,13 +641,13 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe // } // } - "GEMMFMASingle " should "compute fxp matrix multiplication" in { - implicit val config: DataWidthConfig = FxpConfig - test(new GEMMFMASingle(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) - .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - testGEMMFMASingle[Int](dut) - } - } + // "GEMMFMASingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMFMASingle(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMASingle[Int](dut) + // } + // } // "MultiFMAMM " should "compute fp32 dot product" in { // implicit val config: DataWidthConfig = Fp32Config From bc8d2bf494f5d18ae9ca576c383896e54a03a32d Mon Sep 17 00:00:00 2001 From: pyfirstcsh Date: Sat, 28 Dec 2024 15:09:28 +0800 Subject: [PATCH 03/10] update but still error --- src/main/scala/kernel/alu/GemmFMA.scala | 1019 +++++++++++++------ src/test/scala/kernel/alu/GemmFMATest.scala | 749 +++++++++++--- 2 files changed, 1282 insertions(+), 486 deletions(-) diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala index e35aaee..341e752 100644 --- a/src/main/scala/kernel/alu/GemmFMA.scala +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -5,6 +5,7 @@ import chisel3.util._ import kernel.alu.GEMMDataType import kernel.alu.DataWidthConfig import kernel.utils.DebugLog +import kernel.deprecated.PE class currentRowIndex( val m: Int, @@ -16,164 +17,180 @@ class currentRowIndex( val value = Output(Vec(n, UInt(config.outputWidth.W))) //输出的行值 } -class MultiFMAMM( - val k: Int = 64, - val PECount: Int = 16, +// input: matrixA_row: one row of matrixA : 1 * k +// input: matrixB_cols: peCount rows of matrixB : k * peCount +// input: reset: clear pes old data +// output: blockResult: one block of result : 1 * peCount +class MultiFMA( + val k: Int, + val peCount: Int, val gemmType: GEMMDataType.Type )( implicit config: DataWidthConfig) extends Module with DebugLog { val io = IO(new Bundle { - val matrixA_row = Input(Vec(k, UInt(config.inputWidth.W))) // 矩阵1的一行 - val matrixB_cols = Input(Vec(PECount, Vec(k, UInt(config.inputWidth.W)))) // 矩阵2的PECount列向量 - val results = Output(Vec(PECount, UInt(config.outputWidth.W))) // 点积结果 - val valids = Output(Vec(PECount, Bool())) // 结果有效标志 + val matrixA_row = Flipped(Decoupled(Vec(k, UInt(config.inputWidth.W)))) + val matrixB_cols = Flipped(Decoupled(Vec(k, Vec(peCount, UInt(config.inputWidth.W))))) + val blockResult = Decoupled(Vec(peCount, UInt(config.outputWidth.W))) val reset = Input(Bool()) }) - // 创建 PECount 个 PE 实例 - val pes = Seq.fill(PECount)(gemmType match { + val dataValid = io.matrixA_row.valid && io.matrixB_cols.valid + + val readyReg = RegInit(true.B) + io.matrixA_row.ready := readyReg + io.matrixB_cols.ready := readyReg + io.blockResult.valid := false.B + io.blockResult.bits := DontCare + + val pes = Seq.fill(peCount)(gemmType match { case GEMMDataType.Fxp => Module(new PEFxp()).io case GEMMDataType.Fp32 => Module(new PEFp()).io case GEMMDataType.Fp64 => Module(new PEFp()).io case _ => throw new IllegalArgumentException("Unsupported GEMM type") }) - // 当前索引寄存器 - val index = RegInit(0.U(log2Ceil(k).W)) - // 结果有效标志 - val valid = RegInit(false.B) - - // 连接每个 PE 的输入和输出 - for (i <- 0 until PECount) { - val pe = pes(i) - pe.reset := io.reset - pe.in_h := io.matrixA_row(index) // 矩阵1的当前行值 - pe.in_v := io.matrixB_cols(i)(index) // 矩阵2的第i列当前值 - io.results(i) := pe.out // 输出结果 - io.valids(i) := valid // 当点积完成时置位 valid + val optIndex = RegInit(0.U(log2Ceil(k).W)) + val validReg = RegInit(false.B) + + for (i <- 0 until peCount) { + pes(i).reset := io.reset + pes(i).in_h := io.matrixA_row.bits(optIndex) + pes(i).in_v := io.matrixB_cols.bits(optIndex)(i) + io.blockResult.bits(i) := pes(i).out } + io.blockResult.valid := validReg - // 索引、结果有效位控制逻辑 + when(dataValid) { + readyReg := false.B + } when(io.reset) { - index := 0.U - valid := false.B - }.elsewhen(index =/= (k - 1).U) { - index := index + 1.U - valid := false.B - }.elsewhen(index === (k - 1).U) { - index := 0.U - valid := true.B + optIndex := 0.U + validReg := false.B + }.elsewhen(optIndex === (k - 1).U) { + optIndex := 0.U + validReg := true.B + readyReg := true.B + }.otherwise { + validReg := false.B + optIndex := optIndex + 1.U } -} - -class GEMMFMA( - val m: Int, - val k: Int, - val n: Int, - val PECount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - // require(m % PECount == 0 && k % PECount == 0 && n % PECount == 0, "Matrix dimensions must be divisible by PECount") - val io = IO(new Bundle { - val matrixA = Input(Vec(m, Vec(k, UInt(config.inputWidth.W)))) // 矩阵A - val matrixB = Input(Vec(k, Vec(n, UInt(config.inputWidth.W)))) // 矩阵B - val results = Output(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 结果矩阵 - val done = Output(Bool()) // 完成标志 - }) - val rowIndex = RegInit(0.U(log2Ceil(m).W)) // 当前行索引 - val colIndex = RegInit(0.U(log2Ceil(n).W)) // 当前列块索引 - val resultMatrix = Reg(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 存储结果 - val doneFlag = RegInit(false.B) // 完成标志 - val resetFlag = RegInit(false.B) // 复位标志 + // TODO: FSM :reset logic is not correct, need to be fixed - // 实例化MultiFMA - val multiFMA = Module(new MultiFMAMM(k, PECount, gemmType)) - - // 输入连接 - multiFMA.io.matrixA_row := io.matrixA(rowIndex) - multiFMA.io.matrixB_cols := VecInit(Seq.tabulate(PECount) { i => - VecInit(io.matrixB.map(_(colIndex + i.U))) - }) + // pes.foreach { pe => + // pe.in_h := 0.U + // pe.in_v := 0.U + // pe.reset := DontCare + // } - // 一块处理结束, 结果存储, 更新块索引 - when(multiFMA.io.valids.reduce(_ && _)) { - for (i <- 0 until PECount) { - resultMatrix(rowIndex)(colIndex + i.U) := multiFMA.io.results(i) - } - resetFlag := true.B - // 更新块索引 - when(colIndex === (n - PECount).U) { - colIndex := 0.U - when(rowIndex === (m - 1).U) { - rowIndex := 0.U - doneFlag := true.B - }.otherwise { - rowIndex := rowIndex + 1.U - } - }.otherwise { - colIndex := colIndex + PECount.U - } - } + // io.blockResult.valid := validReg + // io.blockResult.bits := DontCare - // 复位控制,也即清空累加器 - when(resetFlag) { - multiFMA.io.reset := true.B - resetFlag := false.B - }.otherwise { - multiFMA.io.reset := false.B - } - - io.done := doneFlag - io.results := resultMatrix + // object state extends ChiselEnum { + // val idle, reset, compute, update, done = Value + // } + // val stateReg = RegInit(state.idle) + + // switch(stateReg) { + // is(state.idle) { + // when(dataValid) { + // readyReg := false.B + // stateReg := state.compute + // } + // } + // is(state.compute) { + // when(io.reset) { + // stateReg := state.reset + // } + // for (i <- 0 until peCount) { + // pes(i).reset := io.reset + // pes(i).in_h := io.matrixA_row.bits(optIndex) + // pes(i).in_v := io.matrixB_cols.bits(optIndex)(i) + // io.blockResult.bits(i) := pes(i).out + // } + + // // printf(p"optIndex: ${optIndex}\n") + // // printf(p"io.matrixA_row.bits(${optIndex}): ${io.matrixA_row.bits(optIndex)}\n") + // // for (i <- 0 until peCount) { + // // printf(p"pe: $i\n") + // // printf(p"io.matrixB_cols.bits(${optIndex})($i): ${io.matrixB_cols.bits(optIndex)(i)}\n") + // // printf(p"io.blockResult.bits(${i}): ${io.blockResult.bits(i)}\n") + // // } + // stateReg := state.update + // } + // is(state.reset) { + // optIndex := 0.U + // validReg := false.B + // stateReg := state.idle + // } + // is(state.update) { + // validReg := false.B + // when(optIndex === (k - 1).U) { + // stateReg := state.done + // }.otherwise { + // optIndex := optIndex + 1.U + // stateReg := state.compute + // } + // } + // is(state.done) { + // optIndex := 0.U + // readyReg := true.B + // validReg := true.B + // stateReg := state.idle + // } + // } } +// input: matrixA: m * k +// input: matrixB: k * n +// output: matrixC: m * n class GEMMFMATotal( val m: Int, val k: Int, val n: Int, - val PECount: Int = 16, + val peCount: Int, val gemmType: GEMMDataType.Type )( implicit config: DataWidthConfig) extends Module with DebugLog { - require(m % PECount == 0 && k % PECount == 0 && n % PECount == 0, "Matrix dimensions must be divisible by PECount") + require(m % peCount == 0 && k % peCount == 0 && n % peCount == 0, "Matrix dimensions must be divisible by peCount") val io = IO(new Bundle { - val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A - val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B - val results = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 结果矩阵 + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val results = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) }) - val rowIndex = Counter(m) - val colIndex = Counter(n / PECount) - - val resultsReg = Reg(Vec(m, Vec(n, UInt(config.outputWidth.W)))) // 存储结果 - val resetFlag = RegInit(false.B) // 复位标志,用于清空MultiFMA - val dataValid = io.matrixA.valid && io.matrixB.valid + val readyReg = RegInit(true.B) io.matrixA.ready := readyReg io.matrixB.ready := readyReg - io.results.valid := false.B io.results.bits := DontCare - // 实例化MultiFMA - val multiFMA = Module(new MultiFMAMM(k, PECount, gemmType)) + val multiFMA = Module(new MultiFMA(k, peCount, gemmType)) - // 输入连接 - multiFMA.io.matrixA_row := io.matrixA.bits(rowIndex.value) - multiFMA.io.matrixB_cols := VecInit(Seq.tabulate(PECount) { i => - VecInit(io.matrixB.bits.map(_((colIndex.value * PECount.U + i.U) % n.U))) - }) + val rowIndex = Counter(m) + val colIndex = Counter(n / peCount) + + multiFMA.io.matrixA_row.valid := io.matrixA.valid + multiFMA.io.matrixA_row.bits := io.matrixA.bits(rowIndex.value) + + multiFMA.io.matrixB_cols.valid := io.matrixB.valid + multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(k) { j => + VecInit(Seq.tabulate(peCount) { i => + io.matrixB.bits(j)((colIndex.value * peCount.U + i.U) % n.U) + }) + }) //k * peCount size block of matrixB + + multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B + + val resultsReg = Reg(Vec(m, Vec(n, UInt(config.outputWidth.W)))) - // 状态机定义 object state extends ChiselEnum { val idle, compute, update, done = Value } @@ -188,16 +205,28 @@ class GEMMFMATotal( } is(state.compute) { - when(multiFMA.io.valids.reduce(_ && _)) { - for (i <- 0 until PECount) { - resultsReg(rowIndex.value)((colIndex.value * PECount.U + i.U) % n.U) := multiFMA.io.results(i) + // printf(p"rowIndex: ${rowIndex.value}, colIndex: ${colIndex.value}\n") + // printf(p"matrixA_row: ${multiFMA.io.matrixA_row.bits}\n") + // for (i <- 0 until peCount) { + // printf(p"matrixB_cols(${i}): ${multiFMA.io.matrixB_cols.bits(i)}\n") + // } + multiFMA.io.reset := false.B + when(multiFMA.io.blockResult.valid) { + for (i <- 0 until peCount) { + resultsReg(rowIndex.value)((colIndex.value * peCount.U + i.U) % n.U) := multiFMA.io.blockResult.bits(i) } - resetFlag := true.B + // for (i <- 0 until m) { + // for (j <- 0 until n) { + // printf(p"i:${i},j:${j},${resultsReg(i)(j)}\t") + // } + // printf(p"\n") + // } stateReg := state.update } } is(state.update) { + multiFMA.io.reset := true.B when(colIndex.inc()) { when(rowIndex.inc()) { stateReg := state.done @@ -216,56 +245,59 @@ class GEMMFMATotal( stateReg := state.idle } } - when(resetFlag) { - multiFMA.io.reset := true.B - resetFlag := false.B - }.otherwise { - multiFMA.io.reset := false.B - } } + +//input: matrixA: m * k +//input: matrixB: k * n +//output: currentRowIndex: one row of matrixC: 1 * n and current row index +//output: done: total matrixC finish flag class GEMMFMASingle( val m: Int, val k: Int, val n: Int, - val PECount: Int = 16, + val peCount: Int, val gemmType: GEMMDataType.Type )( implicit config: DataWidthConfig) extends Module with DebugLog { - require(m % PECount == 0 && k % PECount == 0 && n % PECount == 0, "Matrix dimensions must be divisible by PECount") + require(m % peCount == 0 && k % peCount == 0 && n % peCount == 0, "Matrix dimensions must be divisible by peCount") val io = IO(new Bundle { - val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A - val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B - val currentRow = Decoupled(new currentRowIndex(m, n)) //输出的行索引 - val done = Output(Bool()) // 整个矩阵完成标志 + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val currentRow = Decoupled(new currentRowIndex(m, n)) + val done = Output(Bool()) }) - val rowIndex = Counter(m) - val colIndex = Counter(n / PECount) - - val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) // 存储当前行结果 - val doneFlag = RegInit(false.B) // 完成标志 - val resetFlag = RegInit(false.B) // 复位标志 - val dataValid = io.matrixA.valid && io.matrixB.valid val readyReg = RegInit(true.B) io.matrixA.ready := readyReg io.matrixB.ready := readyReg io.currentRow.valid := false.B io.currentRow.bits := DontCare + io.done := false.B - // 实例化MultiFMA - val multiFMA = Module(new MultiFMAMM(k, PECount, gemmType)) + val multiFMA = Module(new MultiFMA(k, peCount, gemmType)) - // 输入连接 - multiFMA.io.matrixA_row := io.matrixA.bits(rowIndex.value) - multiFMA.io.matrixB_cols := VecInit(Seq.tabulate(PECount) { i => - VecInit(io.matrixB.bits.map(_((colIndex.value * PECount.U + i.U) % n.U))) - }) + val rowIndex = Counter(m) + val colIndex = Counter(n / peCount) + + multiFMA.io.matrixA_row.valid := io.matrixA.valid + multiFMA.io.matrixA_row.bits := io.matrixA.bits(rowIndex.value) + + multiFMA.io.matrixB_cols.valid := io.matrixB.valid + multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(k) { j => + VecInit(Seq.tabulate(peCount) { i => + io.matrixB.bits(j)((colIndex.value * peCount.U + i.U) % n.U) + }) + }) //k * peCount size block of matrixB + + multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B + + val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) - // 状态机定义 object state extends ChiselEnum { val idle, compute, update, done = Value } @@ -280,16 +312,17 @@ class GEMMFMASingle( } is(state.compute) { - when(multiFMA.io.valids.reduce(_ && _)) { - for (i <- 0 until PECount) { - currentRowReg((colIndex.value * PECount.U + i.U) % n.U) := multiFMA.io.results(i) + multiFMA.io.reset := false.B + when(multiFMA.io.blockResult.valid) { + for (i <- 0 until peCount) { + currentRowReg((colIndex.value * peCount.U + i.U) % n.U) := multiFMA.io.blockResult.bits(i) } - resetFlag := true.B stateReg := state.update } } is(state.update) { + multiFMA.io.reset := true.B io.currentRow.valid := false.B when(colIndex.inc()) { io.currentRow.valid := true.B @@ -306,25 +339,18 @@ class GEMMFMASingle( } is(state.done) { - doneFlag := true.B + io.done := true.B readyReg := true.B stateReg := state.idle } } - when(resetFlag) { - multiFMA.io.reset := true.B - resetFlag := false.B - }.otherwise { - multiFMA.io.reset := false.B - } - io.done := doneFlag } class GEMMSingleQueue( val m: Int, val k: Int, val n: Int, - val PECount: Int = 16, + val peCount: Int = 16, val gemmType: GEMMDataType.Type, val bufferSize: Int = 32 )( @@ -349,7 +375,7 @@ class GEMMSingleQueue( hasFlush = true ) ) - val gemm = Module(new GEMMFMASingle(m, k, n, PECount, gemmType)) + val gemm = Module(new GEMMFMASingle(m, k, n, peCount, gemmType)) gemm.io.matrixA <> io.matrixA gemm.io.matrixB <> io.matrixB currentBuffer.io.flush.get := io.flush @@ -361,12 +387,12 @@ class GEMMSingleQueue( // first use GEMMFMATotal to get Q and K, then use GEMMFMASingle to get Q*K^T // out one row of score matrix -class QKMulFMASingle( +class AttnScoresSingle( val m: Int, val k: Int, val n: Int, - val PECount1: Int = 16, - val PECount2: Int = 16, + val peCount1: Int = 16, + val peCount2: Int = 16, val gemmType: GEMMDataType.Type, val bufferSizeGemm: Int = 32 )( @@ -393,8 +419,8 @@ class QKMulFMASingle( io.done := false.B //use GEMMFMATotal to get Q and K - val qGen = Module(new GEMMFMATotal(m, k, n, PECount1, gemmType)) - val kGen = Module(new GEMMFMATotal(m, k, n, PECount1, gemmType)) + val qGen = Module(new GEMMFMATotal(m, k, n, peCount1, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount1, gemmType)) qGen.io.matrixA <> io.inputToken qGen.io.matrixB <> io.weightQ kGen.io.matrixA <> io.inputToken @@ -402,7 +428,7 @@ class QKMulFMASingle( // when qGen and kGen are done, use GEMMFMASingle to get Q*K^T // Q: m * n, K: m * n -> Q*K^T: m * m - val QK_TMul = Module(new GEMMSingleQueue(m, n, m, PECount2, gemmType, bufferSizeGemm)) + val QK_TMul = Module(new GEMMSingleQueue(m, n, m, peCount2, gemmType, bufferSizeGemm)) QK_TMul.io.matrixA <> qGen.io.results val K_T = VecInit(Seq.fill(n)(VecInit(Seq.fill(m)(0.U(config.inputWidth.W))))) @@ -458,234 +484,561 @@ class QKMulFMASingle( } -// TODO: 优化,bug -//first use GEMMFMATotal to get Q and K, then use GEMMFMASingle to get Q*K^T -class QKMulFMA( - val m: Int, - val k: Int, - val n: Int, - val PECount1: Int = 16, - val PECount2: Int = 16, - val gemmType: GEMMDataType.Type, - val bufferSizeGemm: Int = 32 +class AttnScoresTotal( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type )( implicit config: DataWidthConfig) extends Module with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) - class QKGenderWarper( - val m: Int, - val k: Int, - val n: Int, - val PECount: Int = 16, - val gemmType: GEMMDataType.Type, - val bufferSize: Int - )( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A - val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B - val flush = Input(Bool()) - val outMatrix = Decoupled(new currentRowIndex(m, n)) - }) + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - val qkGenMul = Module(new GEMMFMASingle(m, k, n, PECount, gemmType)) - io.matrixA <> qkGenMul.io.matrixA - io.matrixB <> qkGenMul.io.matrixB - - val currentBuffer = Module( - new Queue( - new currentRowIndex(m, n), - entries = bufferSize, - pipe = true, - flow = false, - useSyncReadMem = false, - hasFlush = true - ) - ) + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + io.scores.valid := false.B + io.scores.bits := DontCare + + //use GEMMFMATotal to get Q and K + val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + qGen.io.matrixA := DontCare + qGen.io.matrixB := DontCare + kGen.io.matrixA := DontCare + kGen.io.matrixB := DontCare + qGen.io.results.ready := false.B + kGen.io.results.ready := false.B + + // val Qreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + // val Kreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - // hasFlush must be true - currentBuffer.io.flush.get := io.flush + // when qGen and kGen are done, use GEMMFMASingle to get Q*K^T + // Q: m * n, K: m * n -> Q*K^T: m * m + val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) + QK_TMul.io.matrixA := DontCare + QK_TMul.io.matrixB := DontCare + QK_TMul.io.results.ready := false.B - // ATTENTION: we assert the size of the buffer is huge enough to hold the current systolic group output - // we ignore the ready signal of the enq - currentBuffer.io.enq.bits := qkGenMul.io.currentRow.bits - currentBuffer.io.enq.valid := qkGenMul.io.currentRow.valid + // object state extends ChiselEnum { + // val idle, gen, mul, done = Value + // } + // val stateReg = RegInit(state.idle) + + // switch(stateReg) { + // is(state.idle) { + // when(dataValid) { + // readyReg := false.B + // stateReg := state.gen + // } + // } + // is(state.gen) { + // when(qGen.io.results.valid && kGen.io.results.valid) { + // // Qreg := qGen.io.results.bits + // // Kreg := kGen.io.results.bits + // stateReg := state.mul + // } + // } + // is(state.mul) { + // when(QK_TMul.io.results.valid) { + // stateReg := state.done + // } + // } + // is(state.done) { + // readyReg := true.B + // stateReg := state.idle + // } + // } - io.outMatrix <> currentBuffer.io.deq + when(dataValid) { + readyReg := false.B + qGen.io.matrixA <> io.inputToken + qGen.io.matrixB <> io.weightQ + kGen.io.matrixA <> io.inputToken + kGen.io.matrixB <> io.weightK + qGen.io.results.ready := true.B + kGen.io.results.ready := true.B + when(qGen.io.results.valid && kGen.io.results.valid) { + QK_TMul.io.matrixA.valid := qGen.io.results.valid + QK_TMul.io.matrixA.bits := qGen.io.results.bits + qGen.io.results.ready := true.B + QK_TMul.io.matrixB.valid := kGen.io.results.valid + QK_TMul.io.matrixB.bits := VecInit(kGen.io.results.bits.transpose.map(VecInit(_))) + kGen.io.results.ready := true.B + when(QK_TMul.io.results.valid) { + io.scores.valid := QK_TMul.io.results.valid + io.scores.bits := QK_TMul.io.results.bits + readyReg := true.B + } + } } +} +// QKGenWithReg: use two GEMMFMATotal to get Q and K +// whthin use Reg to store Q and K +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: Query: m * n +// output: Key: m * n +class QKGenWithReg( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { val io = IO(new Bundle { val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val score = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) - val resetBuffer = Input(Bool()) + val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) }) val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + val readyReg = RegInit(true.B) io.inputToken.ready := readyReg io.weightQ.ready := readyReg io.weightK.ready := readyReg + io.Key.valid := false.B + io.Key.bits := DontCare + io.Query.valid := false.B + io.Query.bits := DontCare + + val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + + val Qreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + val Kreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + + qGen.io.matrixA.bits := io.inputToken.bits + qGen.io.matrixA.valid := io.inputToken.valid + qGen.io.matrixB.bits := io.weightQ.bits + qGen.io.matrixB.valid := io.weightQ.valid + qGen.io.results.ready := true.B + + kGen.io.matrixA.bits := io.inputToken.bits + kGen.io.matrixA.valid := io.inputToken.valid + kGen.io.matrixB.bits := io.weightK.bits + kGen.io.matrixB.valid := io.weightK.valid + kGen.io.results.ready := true.B + + io.Query.valid := false.B + io.Query.bits := Qreg + io.Key.valid := false.B + io.Key.bits := Kreg + + // qGen.io.matrixA <> io.inputToken + // qGen.io.matrixB <> io.weightQ + // kGen.io.matrixA <> io.inputToken + // kGen.io.matrixB <> io.weightK + + // io.Query <> qGen.io.results + // io.Key <> kGen.io.results - // QKGen,Q: m * n, K: m * n - val qGen = Module(new QKGenderWarper(m, k, n, PECount1, gemmType, bufferSizeGemm)) - val kGen = Module(new QKGenderWarper(m, k, n, PECount1, gemmType, bufferSizeGemm)) + object state extends ChiselEnum { + val idle, gen, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + when(qGen.io.results.valid && kGen.io.results.valid) { + Qreg := qGen.io.results.bits + Kreg := kGen.io.results.bits + stateReg := state.done + } + } + is(state.done) { + readyReg := true.B + io.Query.valid := true.B + io.Key.valid := true.B + stateReg := state.idle + } + } +} + +// QKGen: use two GEMMFMATotal to get Q and K +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: Query: m * n +// output: Key: m * n +class QKGen( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + io.Key.valid := false.B + io.Key.bits := DontCare + io.Query.valid := false.B + io.Query.bits := DontCare + + val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + + // qGen.io.matrixA.bits := io.inputToken.bits + // qGen.io.matrixA.valid := io.inputToken.valid + // qGen.io.matrixB.bits := io.weightQ.bits + // qGen.io.matrixB.valid := io.weightQ.valid + // qGen.io.results.ready := true.B + + // kGen.io.matrixA.bits := io.inputToken.bits + // kGen.io.matrixA.valid := io.inputToken.valid + // kGen.io.matrixB.bits := io.weightK.bits + // kGen.io.matrixB.valid := io.weightK.valid + // kGen.io.results.ready := true.B + + // io.Query.valid := qGen.io.results.valid + // io.Query.bits := qGen.io.results.bits + // io.Key.valid := kGen.io.results.valid + // io.Key.bits := kGen.io.results.bits qGen.io.matrixA <> io.inputToken qGen.io.matrixB <> io.weightQ kGen.io.matrixA <> io.inputToken kGen.io.matrixB <> io.weightK - qGen.io.flush := io.resetBuffer - kGen.io.flush := io.resetBuffer + io.Query <> qGen.io.results + io.Key <> kGen.io.results - // // QKMul Q*K^T, Q: m * n, K: m * n -> m * m - // val Qrow = qGen.io.outMatrix.bits.value // one row of Q: 1 * n - // val Krow = kGen.io.outMatrix.bits.value // one row of K: 1 * n - // val QIndex = qGen.io.outMatrix.bits.index // the index of Q row - // val KIndex = kGen.io.outMatrix.bits.index // the index of K row + object state extends ChiselEnum { + val idle, gen, done = Value + } + val stateReg = RegInit(state.idle) - // 创建一个 MultiFMAMM 模块来计算 Q 的一行和 K 的多列的乘积结果 中间维度为n - // val multiFMA = Module(new MultiFMAMM(n, PECount2, gemmType)) + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + when(qGen.io.results.valid && kGen.io.results.valid) { + stateReg := state.done + } + } + is(state.done) { + readyReg := true.B + io.Query.valid := true.B + io.Key.valid := true.B + stateReg := state.idle + } + } +} - val qQueue = Module(new Queue(new currentRowIndex(m, n), bufferSizeGemm)) - val kQueue = Module(new Queue(new currentRowIndex(m, n), bufferSizeGemm)) +// QKMulTotalWithReg: use GEMMFMATotal to get scores +// input: Query: m * n +// input: Key: m * n +// output: scores: m * m +class QKMulTotalWithReg( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Query.valid && io.Key.valid + + val readyReg = RegInit(true.B) + io.Query.ready := readyReg + io.Key.ready := readyReg + io.scores.valid := false.B + io.scores.bits := DontCare + + val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) + val scoresReg = Reg(Vec(m, Vec(m, UInt(config.outputWidth.W)))) - // 将生成的每一行数据存储到队列中 - qQueue.io.enq.bits := qGen.io.outMatrix.bits - qQueue.io.enq.valid := qGen.io.outMatrix.valid - kQueue.io.enq.bits := kGen.io.outMatrix.bits - kQueue.io.enq.valid := kGen.io.outMatrix.valid -// 创建一个 M*N 的寄存器组来保存所有的 Q 和 K 值 - val qMatrix = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - val k_TMatrix = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.results.ready := true.B + + // io.scores.valid := QK_TMul.io.results.valid + // io.scores.bits := QK_TMul.io.results.bits + io.scores.valid := false.B + io.scores.bits := scoresReg - // 状态机定义 object state extends ChiselEnum { - val idle, load, compute, done = Value + val idle, mul, done = Value } val stateReg = RegInit(state.idle) - // 计数器,用于跟踪 Q 和 K 的行数 - val qCounter = RegInit(0.U(log2Ceil(m).W)) - val kCounter = RegInit(0.U(log2Ceil(m).W)) - - // 创建一个 M*M 的寄存器组来保存所有的结果 - val scoreValue = RegInit(VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W)))))) - switch(stateReg) { is(state.idle) { - when(qQueue.io.enq.valid && kQueue.io.enq.valid) { - stateReg := state.load + when(dataValid) { + readyReg := false.B + stateReg := state.mul } } - - is(state.load) { - when(qQueue.io.deq.valid && kQueue.io.deq.valid) { - qMatrix(qQueue.io.deq.bits.index) := qQueue.io.deq.bits.value - for (i <- 0 until n) { - k_TMatrix(i)(kQueue.io.deq.bits.index) := kQueue.io.deq.bits.value(i) // 将 K 的值存储到转置后的 kMatrix 中 - } - qCounter := qCounter + 1.U - kCounter := kCounter + 1.U - qQueue.io.deq.ready := true.B - kQueue.io.deq.ready := true.B - when(qCounter === (m - 1).U && kCounter === (m - 1).U) { - stateReg := state.compute - } + is(state.mul) { + when(QK_TMul.io.results.valid) { + // for (i <- 0 until m) { + // for (j <- 0 until n) { + // // printf(p"QK_TMul.io.matrixA.bits($i)($j): ${QK_TMul.io.matrixA.bits(i)(j)}\n") + // // printf(p"QK_TMul.io.matrixB.bits($i)($j): ${QK_TMul.io.matrixB.bits(i)(j)}\n") + // printf(p"io.Query.bits($i)($j): ${io.Query.bits(i)(j)}\n") + // printf(p"io.Key.bits($i)($j): ${io.Key.bits(i)(j)}\n") + // } + // } + // for (i <- 0 until m) { + // for (j <- 0 until m) { + // printf(p"QK_TMul.io.results.bits($i)($j): ${QK_TMul.io.results.bits(i)(j)}\n") + // } + // } + scoresReg := QK_TMul.io.results.bits + stateReg := state.done } } - - is(state.compute) { - val multiFMA = Module(new GEMMFMA(m, n, m, PECount2, gemmType)) - multiFMA.io.matrixA := qMatrix - multiFMA.io.matrixB := k_TMatrix - io.score.bits := multiFMA.io.results - io.score.valid := multiFMA.io.done - stateReg := state.done + is(state.done) { + readyReg := true.B + io.scores.valid := true.B + stateReg := state.idle } + } +} + +// QKMulTotal: use GEMMFMATotal to get scores +// input: Query: m * n +// input: Key: m * n +// output: scores: m * m +class QKMulTotal( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Query.valid && io.Key.valid + + val readyReg = RegInit(true.B) + io.Query.ready := readyReg + io.Key.ready := readyReg + io.scores.valid := false.B + io.scores.bits := DontCare + val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) + + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.results.ready := true.B + + io.scores.valid := QK_TMul.io.results.valid + io.scores.bits := QK_TMul.io.results.bits + + object state extends ChiselEnum { + val idle, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.mul + } + } + is(state.mul) { + when(QK_TMul.io.results.valid) { + // for (i <- 0 until m) { + // for (j <- 0 until n) { + // // printf(p"QK_TMul.io.matrixA.bits($i)($j): ${QK_TMul.io.matrixA.bits(i)(j)}\n") + // // printf(p"QK_TMul.io.matrixB.bits($i)($j): ${QK_TMul.io.matrixB.bits(i)(j)}\n") + // printf(p"io.Query.bits($i)($j): ${io.Query.bits(i)(j)}\n") + // printf(p"io.Key.bits($i)($j): ${io.Key.bits(i)(j)}\n") + // } + // } + // for (i <- 0 until m) { + // for (j <- 0 until m) { + // printf(p"QK_TMul.io.results.bits($i)($j): ${QK_TMul.io.results.bits(i)(j)}\n") + // } + // } + stateReg := state.done + } + } is(state.done) { - // 完成标志 - io.score.valid := true.B + readyReg := true.B + io.scores.valid := true.B stateReg := state.idle } } +} - when(io.resetBuffer) { - qCounter := 0.U - kCounter := 0.U - scoreValue := VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W))))) - stateReg := state.idle - } - // // 创建一个 M*N 的寄存器组来保存所有的 K 值 - // val kMatrix = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - - // // 从队列中提取 Q 的一行和 K 的多列 - // val qRowFromQueue = qQueue.io.deq.bits.value - // val qIndexFromQueue = qQueue.io.deq.bits.index - - // val kColsFromQueue = Reg(Vec(PECount2, Vec(n, UInt(config.inputWidth.W)))) - // val kIndexFromQueue = Reg(Vec(PECount2, UInt(log2Ceil(m).W))) - - // // 计数器,用于跟踪 K 的列数 - // val kCounter = RegInit(0.U(log2Ceil(n / PECount2).W)) - - // // 当 K 队列中有足够的列时,提取 K 的多列 - // when(kQueue.io.deq.valid && kCounter < PECount2.U) { - // kColsFromQueue(kCounter) := kQueue.io.deq.bits.value - // kIndexFromQueue(kCounter) := kQueue.io.deq.bits.index - // kMatrix(kQueue.io.deq.bits.index) := kQueue.io.deq.bits.value - // kCounter := kCounter + 1.U - // kQueue.io.deq.ready := true.B - // }.otherwise { - // kQueue.io.deq.ready := false.B - // } +// AttnScores: use QKGen to get Q and K, then use QKMulTotal to get scores +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: scores: m * m +class AttnScores( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) - // // 当 K 队列中有足够的列时,进行矩阵乘法 - // when(kCounter === PECount2.U) { - // multiFMA.io.matrixA_row := qRowFromQueue - // multiFMA.io.matrixB_cols := kColsFromQueue - // kCounter := 0.U - // } + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - // // 连接结果和有效标志 - // val scoreValue = RegInit(VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W)))))) - // for (i <- 0 until PECount2) { - // when(multiFMA.io.valids(i)) { - // scoreValue(qIndexFromQueue)(kIndexFromQueue(i)) := multiFMA.io.results(i) - // } - // } + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg - // io.score.bits := scoreValue - // io.score.valid := qQueue.io.deq.valid && kQueue.io.deq.valid + io.scores.valid := false.B + io.scores.bits := DontCare - // // 当 qQueue 继续有值时,继续处理 - // when(qQueue.io.deq.valid && kQueue.io.deq.valid) { - // qQueue.io.deq.ready := true.B - // }.otherwise { - // qQueue.io.deq.ready := false.B - // } + val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) - // when(io.resetBuffer) { - // kCounter := 0.U - // scoreValue := VecInit(Seq.fill(m)(VecInit(Seq.fill(m)(0.U(config.outputWidth.W))))) - // } + QKGen.io.inputToken.valid := io.inputToken.valid + QKGen.io.inputToken.bits := io.inputToken.bits + QKGen.io.weightQ.valid := io.weightQ.valid + QKGen.io.weightQ.bits := io.weightQ.bits + QKGen.io.weightK.valid := io.weightK.valid + QKGen.io.weightK.bits := io.weightK.bits + QKGen.io.Query.ready := true.B + QKGen.io.Key.ready := true.B - // // final result idx - // val rowIdx = RegInit(0.U(log2Ceil(m / PECount2).W)) - // val colIdx = RegInit(0.U(log2Ceil(m / PECount2).W)) - // val resValid = RegInit(false.B) - // io.score.valid := resValid + // val QueryReg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + // val KeyReg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - // io.score.bits := scoreValue + val QKMul = Module(new QKMulTotalWithReg(m, n, peCount, gemmType)) - // when(resValid && io.score.ready) { - // resValid := false.B - // } + val scoresReg = Reg(Vec(m, Vec(m, UInt(config.inputWidth.W)))) + + QKMul.io.Query.valid := QKGen.io.Query.valid + QKMul.io.Query.bits := QKGen.io.Query.bits + // QKMul.io.Query.bits := QueryReg + QKMul.io.Key.valid := QKGen.io.Key.valid + QKMul.io.Key.bits := QKGen.io.Key.bits + // QKMul.io.Key.bits := KeyReg + QKMul.io.scores.ready := true.B + + // io.scores.valid := QKMul.io.scores.valid + // io.scores.bits := QKMul.io.scores.bits + io.scores.valid := false.B + io.scores.bits := scoresReg + + object state extends ChiselEnum { + val idle, gen, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + printf(p"gen:\n") + when(QKGen.io.Query.valid && QKGen.io.Key.valid) { + // QueryReg := QKGen.io.Query.bits + // KeyReg := QKGen.io.Key.bits + // for (i <- 0 until m) { + // for (j <- 0 until n) { + // // printf(p"QueryReg($i)($j): ${QueryReg(i)(j)}\n") + // // printf(p"KeyReg($i)($j): ${KeyReg(i)(j)}\n") + // printf(p"QKGen.io.Query.bits($i)($j): ${QKGen.io.Query.bits(i)(j)}\n") + // printf(p"QKGen.io.Key.bits($i)($j): ${QKGen.io.Key.bits(i)(j)}\n") + // } + // } + stateReg := state.mul + } + } + is(state.mul) { + printf(p"mul:\n") + when(QKMul.io.scores.valid) { + // for (i <- 0 until m) { + // for (j <- 0 until n) { + // printf(p"QKMul.io.Query.bits($i)($j): ${QKMul.io.Query.bits(i)(j)}\n") + // printf(p"QKMul.io.Key.bits($i)($j): ${QKMul.io.Key.bits(i)(j)}\n") + // } + // } + // for (i <- 0 until m) { + // for (j <- 0 until m) { + // printf(p"QKMul.io.scores.bits($i)($j): ${QKMul.io.scores.bits(i)(j)}\n") + // } + // } + scoresReg := QKMul.io.scores.bits + stateReg := state.done + } + } + is(state.done) { + readyReg := true.B + io.scores.valid := true.B + stateReg := state.idle + } + } } + diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala index f501539..c5659b1 100644 --- a/src/test/scala/kernel/alu/GemmFMATest.scala +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -187,46 +187,67 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") } } - private def testMultiFMAMM[T: Numeric: ClassTag]( - dut: MultiFMAMM + + private def testMultiFMA[T: Numeric: ClassTag]( + dut: MultiFMA )( implicit config: DataWidthConfig ): Unit = { val k = dut.k - val PECount = dut.PECount + val peCount = dut.peCount val gemmType = dut.gemmType + // val fixedMatrix = Array( + // Array(4, 2, 3, 1) + // ) + // val fixedMatrix2 = Array( + // Array(4, 2, 3, 1), + // Array(0, 5, 1, 3), + // Array(4, 2, 1, 0), + // Array(0, 3, 1, 3) + // ) + // val matrixA_row = fixedMatrix + // val matrixB_cols = fixedMatrix2 val matrixA_row = matInit[T](1, k) - val matrixB_cols = matInit[T](k, PECount) + val matrixB_cols = matInit[T](k, peCount) + val expectedResults = mmul(matrixA_row, matrixB_cols) printmat(matrixA_row) printmat(matrixB_cols) printmat(expectedResults) - // 初始化输入 + dut.io.reset.poke(true.B) dut.clock.step(1) dut.io.reset.poke(false.B) - // 逐元素输入数据 - for (i <- matrixA_row(0).indices) { - dut.io.matrixA_row(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) - for (j <- 0 until PECount) { - dut.io.matrixB_cols(j)(i).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) + if (dut.io.matrixA_row.ready.peekBoolean() && dut.io.matrixB_cols.ready.peekBoolean()) { + println("matrixA_row and matrixB_cols are ready") + dut.io.matrixA_row.valid.poke(true.B) + dut.io.matrixB_cols.valid.poke(true.B) + for (i <- matrixA_row(0).indices) { + for (j <- 0 until peCount) { + dut.io.matrixA_row.bits(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) + dut.io.matrixB_cols.bits(i)(j).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) + } } + } else { + dut.io.matrixA_row.valid.poke(false.B) + dut.io.matrixB_cols.valid.poke(false.B) } - while (!dut.io.valids.forall(_.peekBoolean())) { + while (!dut.io.blockResult.valid.peekBoolean()) { dut.clock.step() } + dut.io.blockResult.ready.poke(true.B) + val precision = 0.001f var invalidcnt = 0 - for (i <- 0 until PECount) { - val outBigInt = dut.io.results(i).peekInt() + for (i <- 0 until peCount) { + val outBigInt = dut.io.blockResult.bits(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) val expected = expectedResults(0)(i) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { case c if c == classOf[Float] => math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision @@ -244,70 +265,8 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe printmat(Array(Array(expected))) invalidcnt += 1 } - dut.io.valids(i).expect(true.B) - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - private def testGEMMFMA[T: Numeric: ClassTag]( - dut: GEMMFMA - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val k = dut.k - val n = dut.n - val PECount = dut.PECount - val gemmType = dut.gemmType - - val matrixA = matInit[T](m, k) - val matrixB = matInit[T](k, n) - val expectedResults = mmul(matrixA, matrixB) - - // 初始化输入 - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.matrixA(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) - dut.io.matrixB(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) - } - } - } - - while (!dut.io.done.peekBoolean()) { - dut.clock.step() } - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until n) { - val outBigInt = dut.io.results(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - }) - // printmat(Array(Array(out))) - // printmat(Array(Array(expected))) - if (isInvalid) { - println("Error: ") - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } @@ -320,13 +279,15 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val m = dut.m val k = dut.k val n = dut.n - val PECount = dut.PECount + val peCount = dut.peCount val gemmType = dut.gemmType val matrixA = matInit[T](m, k) val matrixB = matInit[T](k, n) val expectedResults = mmul(matrixA, matrixB) - // printmat(expectedResults) + printmat(matrixA) + printmat(matrixB) + printmat(expectedResults) if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { println("matrixA and matrixB are ready") @@ -349,6 +310,8 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe dut.clock.step() } + dut.io.results.ready.poke(true.B) + val precision = 0.001f var invalidcnt = 0 @@ -370,7 +333,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe // printmat(Array(Array(out))) // printmat(Array(Array(expected))) if (isInvalid) { - println("Error: ") + println("Error: row: " + row + " col: " + col) printmat(Array(Array(out))) printmat(Array(Array(expected))) invalidcnt += 1 @@ -389,7 +352,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val m = dut.m val k = dut.k val n = dut.n - val PECount = dut.PECount + val peCount = dut.peCount val gemmType = dut.gemmType val matrixA = matInit[T](m, k) @@ -525,27 +488,24 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } - private def testQKMulFMASingle[T: Numeric: ClassTag]( - dut: QKMulFMASingle + + private def testQKGen[T: Numeric: ClassTag]( + dut: QKGen )( implicit config: DataWidthConfig ): Unit = { val m = dut.m val k = dut.k val n = dut.n - val PECount1 = 4 - val PECount2 = 4 val gemmType = dut.gemmType val inputToken = matInit[T](m, k) val weightQ = matInit[T](k, n) val weightK = matInit[T](k, n) - val W_q = mmul(inputToken, weightQ) - printmat(W_q) - val W_k = mmul(inputToken, weightK) - printmat(W_k.transpose) - val expectedResults = mmul(W_q, W_k.transpose) // W_q * W_k^T - printmat(expectedResults) + val Query = mmul(inputToken, weightQ) + printmat(Query) + val Key = mmul(inputToken, weightK) + printmat(Key) if ( dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() @@ -569,131 +529,614 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe dut.io.weightK.valid.poke(false.B) } - dut.io.scoreRow.ready.poke(true.B) + while (!(dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean())) { + dut.clock.step() + } + + dut.io.Key.ready.poke(true.B) + dut.io.Query.ready.poke(true.B) val precision = 0.001f var invalidcnt = 0 - while (!dut.io.done.peekBoolean()) { - if (dut.io.scoreRow.valid.peekBoolean()) { - val scoreRowIndex = dut.io.scoreRow.bits.index.peekInt() - println("scoreRow index: " + scoreRowIndex) - for (i <- 0 until n) { - val outBigInt = dut.io.scoreRow.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(scoreRowIndex.toInt)(i) - println("i: " + i) + for (row <- 0 until m) { + for (col <- 0 until n) { + val outBigInt = dut.io.Query.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = Query(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) printmat(Array(Array(out))) printmat(Array(Array(expected))) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + invalidcnt += 1 + } + } + } - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + for (row <- 0 until m) { + for (col <- 0 until n) { + val outBigInt = dut.io.Key.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = Key(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - }) - if (isInvalid) { - println("Error: ") - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testQKGenWithReg[T: Numeric: ClassTag]( + dut: QKGenWithReg + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val Query = mmul(inputToken, weightQ) + printmat(Query) + val Key = mmul(inputToken, weightK) + printmat(Key) + + if ( + dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) } } } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + while (!(dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean())) { dut.clock.step() } + dut.io.Key.ready.poke(true.B) + dut.io.Query.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for (row <- 0 until m) { + for (col <- 0 until n) { + val outBigInt = dut.io.Query.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = Query(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + + for (row <- 0 until m) { + for (col <- 0 until n) { + val outBigInt = dut.io.Key.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = Key(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } - "QKMulFMASingle " should "compute fxp matrix multiplication" in { - implicit val config: DataWidthConfig = FxpConfig - test(new QKMulFMASingle(m = 4, k = 4, n = 4, PECount1 = 4, PECount2 = 4, gemmType = GEMMDataType.Fxp)) - .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - testQKMulFMASingle[Int](dut) + private def testQKMulTotal[T: Numeric: ClassTag]( + dut: QKMulTotal + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + // println("m: " + m + " n: " + n) + + // val fixedMatrix = Array( + // Array(4, -1, 3, 1), + // Array(0, 5, -3, 3), + // Array(4, -2, 4, 0), + // Array(0, 3, -1, 3) + // ) + // val Query = fixedMatrix + // val Key = fixedMatrix + val Query = matInit[T](m, n) + val Key = matInit[T](m, n) + val expectedResults = mmul(Query, Key.transpose) + + println("Query:") + printmat(Query) + println("Key:") + printmat(Key) + println("expectedResults:") + printmat(expectedResults) + + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println(" Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) + } } + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + + while (!dut.io.scores.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for (row <- 0 until m) { + for (col <- 0 until m) { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") } - // "GEMMSingleQueue " should "compute fxp matrix multiplication" in { + private def testQKMulTotalWithReg[T: Numeric: ClassTag]( + dut: QKMulTotalWithReg + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + // println("m: " + m + " n: " + n) + + // val fixedMatrix = Array( + // Array(4, -1, 3, 1), + // Array(0, 5, -3, 3), + // Array(4, -2, 4, 0), + // Array(0, 3, -1, 3) + // ) + // val Query = fixedMatrix + // val Key = fixedMatrix + val Query = matInit[T](m, n) + val Key = matInit[T](m, n) + val expectedResults = mmul(Query, Key.transpose) + + println("Query:") + printmat(Query) + println("Key:") + printmat(Key) + println("expectedResults:") + printmat(expectedResults) + + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println(" Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) + } + } + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + + while (!dut.io.scores.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for (row <- 0 until m) { + for (col <- 0 until m) { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testAttnScores[T: Numeric: ClassTag]( + dut: AttnScores + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val Query = mmul(inputToken, weightQ) + printmat(Query) + val Key = mmul(inputToken, weightK) + printmat(Key.transpose) + val expectedResults = mmul(Query, Key.transpose) // Query * Key^T + printmat(expectedResults) + + if ( + dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + } + } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + while (!dut.io.scores.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for (row <- 0 until m) { + for (col <- 0 until m) { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + private def testAttnScoresTotal[T: Numeric: ClassTag]( + dut: AttnScoresTotal + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val Query = mmul(inputToken, weightQ) + printmat(Query) + val Key = mmul(inputToken, weightK) + printmat(Key.transpose) + val expectedResults = mmul(Query, Key.transpose) // Query * Key^T + printmat(expectedResults) + + if ( + dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until k) { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + } + } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + while (!dut.io.scores.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for (row <- 0 until m) { + for (col <- 0 until m) { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + // ===--::--=== + // below tests ERROR + // ===--::--=== + + // "AttnScoresTotal " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new GEMMSingleQueue(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) + // test(new AttnScoresTotal(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMSingleQueue[Int](dut) + // testAttnScoresTotal[Int](dut) // } // } - // "GEMMFMATotal " should "compute fxp matrix multiplication" in { + + // "AttnScores " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new GEMMFMATotal(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) + // test(new AttnScores(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMFMATotal[Int](dut) + // testAttnScores[Int](dut) // } // } - // "GEMMFMASingle " should "compute fp32 matrix multiplication" in { + // "AttnScores " should "compute fp32 matrix multiplication" in { // implicit val config: DataWidthConfig = Fp32Config - // test(new GEMMFMASingle(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fp32)) + // test(new AttnScores(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMFMASingle[Float](dut) + // testAttnScores[Float](dut) // } // } - // "GEMMFMASingle " should "compute fxp matrix multiplication" in { + // ===--::--=== + // below tests PASS + // ===--::--=== + + // "QKMulTotal " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new GEMMFMASingle(m = 4, k = 8, n = 12, PECount = 4, gemmType = GEMMDataType.Fxp)) + // test(new QKMulTotal(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMFMASingle[Int](dut) + // testQKMulTotal[Int](dut) // } // } - // "MultiFMAMM " should "compute fp32 dot product" in { - // implicit val config: DataWidthConfig = Fp32Config - // test(new MultiFMAMM(k = 4, PECount = 16, gemmType = GEMMDataType.Fp32)) + // "QKMulTotalWithReg " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMulTotalWithReg(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testMultiFMAMM[Float](dut) + // testQKMulTotalWithReg[Int](dut) // } // } - // "MultiFMAMM " should "compute fp64 dot product" in { - // implicit val config: DataWidthConfig = Fp64Config - // test(new MultiFMAMM(k = 4, PECount = 16, gemmType = GEMMDataType.Fp64)) + // "QKGen " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKGen(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testMultiFMAMM[Double](dut) + // testQKGen[Int](dut) // } // } - // "MultiFMAMM " should "compute fxp dot product" in { + // "QKGenWithReg " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new MultiFMAMM(k = 4, PECount = 16, gemmType = GEMMDataType.Fxp)) + // test(new QKGenWithReg(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testMultiFMAMM[Int](dut) + // testQKGenWithReg[Int](dut) // } // } - // "GEMMFMA " should "compute fp32 matrix multiplication" in { + // "GEMMSingleQueue " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMSingleQueue[Int](dut) + // } + // } + + // "GEMMSingleQueue " should "compute fp32 matrix multiplication" in { // implicit val config: DataWidthConfig = Fp32Config - // test(new GEMMFMA(m = 4, k = 4, n = 16, PECount = 16, gemmType = GEMMDataType.Fp32)) + // test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fp32)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMFMA[Float](dut) + // testGEMMSingleQueue[Float](dut) // } // } - // "GEMMFMA " should "compute fp64 matrix multiplication" in { - // implicit val config: DataWidthConfig = Fp64Config - // test(new GEMMFMA(m = 4, k = 4, n = 16, PECount = 16, gemmType = GEMMDataType.Fp64)) + // "GEMMFMATotal " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMFMATotal(m = 4, k = 4, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMFMA[Double](dut) + // testGEMMFMATotal[Int](dut) // } // } - // "GEMMFMA " should "compute fxp matrix multiplication" in { + // "GEMMFMATotal " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new GEMMFMATotal(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMATotal[Float](dut) + // } + // } + + // "GEMMFMASingle " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new GEMMFMASingle(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMASingle[Float](dut) + // } + // } + + // "GEMMFMASingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMFMASingle(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMFMASingle[Int](dut) + // } + // } + + // "MultiFMA " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new GEMMFMA(m = 4, k = 4, n = 16, PECount = 16, gemmType = GEMMDataType.Fxp)) + // test(new MultiFMA(k = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMA[Int](dut) + // } + // } + + // "MultiFMA " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new MultiFMA( k = 4, peCount = 4, gemmType = GEMMDataType.Fp32)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMFMA[Int](dut) + // testMultiFMA[Float](dut) // } // } } From 3d6b734da0f01b745dea943cf0b205cdffebde2f Mon Sep 17 00:00:00 2001 From: pyfirstcsh Date: Sun, 5 Jan 2025 21:35:05 +0800 Subject: [PATCH 04/10] add OutValue&Single --- src/main/scala/kernel/alu/GemmFMA.scala | 145 +++++++++++++++++ src/test/scala/kernel/alu/GemmFMATest.scala | 164 ++++++++++++++++++++ 2 files changed, 309 insertions(+) diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala index 341e752..be78cd6 100644 --- a/src/main/scala/kernel/alu/GemmFMA.scala +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -1042,3 +1042,148 @@ class AttnScores( } +// OutValue: get the final output value +// input: AttnWeights: m * m +// input: Value: m * n +// output: AttnOut: m * n +class OutValue( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val AttnWeights = Flipped(Decoupled(Vec(m, Vec(m, UInt(config.inputWidth.W))))) + val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val AttnOut = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.AttnWeights.valid && io.Value.valid + + val readyReg = RegInit(true.B) + io.AttnWeights.ready := readyReg + io.Value.ready := readyReg + io.AttnOut.valid := false.B + io.AttnOut.bits := DontCare + + val ValueMul = Module(new GEMMFMATotal(m, m, n, peCount, gemmType)) + + ValueMul.io.matrixA.valid := io.AttnWeights.valid + ValueMul.io.matrixA.bits := io.AttnWeights.bits + ValueMul.io.matrixB.valid := io.Value.valid + ValueMul.io.matrixB.bits := io.Value.bits + ValueMul.io.results.ready := true.B + + io.AttnOut.valid := ValueMul.io.results.valid + io.AttnOut.bits := ValueMul.io.results.bits + +} + +// OutValue: get the final output value +// input: one row of AttnWeights: 1 * m ,total m rows +// input: Value: m * n +// output: one row of AttnOut: 1 * n ,total m rows +// output: done: Bool +class OutValueSingle( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val currentAttnW = Flipped(Decoupled(new currentRowIndex(m, m))) + val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val currentAttnO = Decoupled(new currentRowIndex(m, n)) + val done = Output(Bool()) + }) + + val dataValid = io.currentAttnW.valid && io.Value.valid + + val readyReg = RegInit(true.B) + io.currentAttnW.ready := readyReg + io.Value.ready := readyReg + io.currentAttnO.valid := false.B + io.currentAttnO.bits := DontCare + io.done := false.B + + val multiFMA = Module(new MultiFMA(m, peCount, gemmType)) + + val rowIndex = Counter(m) + val colIndex = Counter(n / peCount) + + multiFMA.io.matrixA_row.valid := io.currentAttnW.valid + multiFMA.io.matrixA_row.bits := io.currentAttnW.bits.value + + multiFMA.io.matrixB_cols.valid := io.Value.valid + multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(m) { j => + VecInit(Seq.tabulate(peCount) { i => + io.Value.bits(j)((colIndex.value * peCount.U + i.U) % n.U) + }) + }) //m * peCount size block of Value + + multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B + + val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) + + object state extends ChiselEnum { + val idle, compute, update, done = Value + } + + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + // readyReg := false.B + stateReg := state.compute + } + } + + is(state.compute) { + multiFMA.io.reset := false.B + // printf(p"multiFMA.io.matrixA_row.bits: ${multiFMA.io.matrixA_row.bits}\n") + // printf(p"multiFMA.io.matrixB_cols.bits: ${multiFMA.io.matrixB_cols.bits}\n") + + when(multiFMA.io.blockResult.valid) { + for (i <- 0 until peCount) { + currentRowReg(colIndex.value * peCount.U + i.U) := multiFMA.io.blockResult.bits(i) + } + stateReg := state.update + } + } + + is(state.update) { + multiFMA.io.reset := true.B + io.currentAttnO.valid := false.B + when(colIndex.inc()) { + io.currentAttnO.valid := true.B + io.currentAttnO.bits.index := rowIndex.value + io.currentAttnO.bits.value := currentRowReg + // readyReg := true.B + // io.currentAttnO.ready := true.B + // io.Value.ready := true.B + when(rowIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.compute + // wait for next row of AttnWeights + } + }.otherwise { + stateReg := state.compute + } + + } + is(state.done) { + io.done := true.B + readyReg := true.B + stateReg := state.idle + } + } +} diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala index c5659b1..87259d9 100644 --- a/src/test/scala/kernel/alu/GemmFMATest.scala +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -934,6 +934,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } + private def testAttnScoresTotal[T: Numeric: ClassTag]( dut: AttnScoresTotal )( @@ -1012,10 +1013,157 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe else println(s"Verification failed with $invalidcnt errors.") } + private def testOutValue[T: Numeric: ClassTag]( + dut: OutValue + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val AttnWeights = matInit[T](m, m) + val Value = matInit[T](m, n) + val expectedResults = mmul(AttnWeights, Value) + + if (dut.io.AttnWeights.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { + println("AttnWeights and Value are ready") + dut.io.AttnWeights.valid.poke(true.B) + dut.io.Value.valid.poke(true.B) + for (row <- 0 until m) { + for (col <- 0 until n) { + for (i <- 0 until m) { + dut.io.AttnWeights.bits(row)(i).poke(toBinaryBigInt(AttnWeights(row)(i)).U) + dut.io.Value.bits(i)(col).poke(toBinaryBigInt(Value(i)(col)).U) + } + } + } + } else { + dut.io.AttnWeights.valid.poke(false.B) + dut.io.Value.valid.poke(false.B) + } + + while (!dut.io.AttnOut.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.AttnOut.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for (row <- 0 until m) { + for (col <- 0 until n) { + val outBigInt = dut.io.AttnOut.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + + }) + if (isInvalid) { + println("Error: row: " + row + " col: " + col) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testOutValueSingle[T: Numeric: ClassTag]( + dut: OutValueSingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val AttnWeights = matInit[T](m, m) + val Value = matInit[T](m, n) + val expectedResults = mmul(AttnWeights, Value) + printmat(AttnWeights) + printmat(Value) + printmat(expectedResults) + + val precision = 0.001f + var invalidcnt = 0 + for (index <- 0 until m) { + println("index: " + index) + if (dut.io.currentAttnW.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { + println("currentAttnW index :" + index + " and Value are ready") + + dut.io.currentAttnW.valid.poke(true.B) + dut.io.Value.valid.poke(true.B) + for (i <- 0 until m) { + dut.io.currentAttnW.bits.value(i).poke(toBinaryBigInt(AttnWeights(index)(i)).U) + for (j <- 0 until n) { + dut.io.Value.bits(i)(j).poke(toBinaryBigInt(Value(i)(j)).U) + } + } + + } else { + dut.io.currentAttnW.valid.poke(false.B) + dut.io.Value.valid.poke(false.B) + } + while (!dut.io.currentAttnO.valid.peekBoolean()) { + dut.io.currentAttnO.ready.poke(false.B) + dut.clock.step() + } + + dut.io.currentAttnO.ready.poke(true.B) + + val currentRowIndex = dut.io.currentAttnO.bits.index.peekInt() + // println("currentRow index:" + currentRowIndex + " expected: " + index) + for (i <- 0 until n) { + val outBigInt = dut.io.currentAttnO.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(currentRowIndex.toInt)(i) + + val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + }) + if (isInvalid) { + println("Error: " + i) + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + invalidcnt += 1 + } + } + dut.clock.step() + + } + + // while (!dut.io.done.peekBoolean()) { + // dut.clock.step() + // } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + // ===--::--=== // below tests ERROR // ===--::--=== + // "AttnScoresTotal " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig // test(new AttnScoresTotal(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) @@ -1044,6 +1192,22 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe // below tests PASS // ===--::--=== + // "OutValueSingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new OutValueSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testOutValueSingle[Int](dut) + // } + // } + + // "OutValue " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new OutValue(m = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testOutValue[Int](dut) + // } + // } + // "QKMulTotal " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig // test(new QKMulTotal(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) From 82f2ef6efbdd2a31b7ec6bf39e53e53a8cfd6b28 Mon Sep 17 00:00:00 2001 From: pyfirstcsh Date: Mon, 6 Jan 2025 22:59:29 +0800 Subject: [PATCH 05/10] fix && AttnScores pass --- src/main/scala/kernel/alu/AttnScores.scala | 400 ++++++ src/main/scala/kernel/alu/GemmFMA.scala | 820 +----------- src/main/scala/kernel/alu/OutValue.scala | 177 +++ .../scala/kernel/alu/AttnScoresTest.scala | 281 ++++ src/test/scala/kernel/alu/GemmFMATest.scala | 1129 ++--------------- src/test/scala/kernel/alu/OutValueTest.scala | 162 +++ src/test/scala/kernel/alu/utils.scala | 214 ++++ 7 files changed, 1313 insertions(+), 1870 deletions(-) create mode 100644 src/main/scala/kernel/alu/AttnScores.scala create mode 100644 src/main/scala/kernel/alu/OutValue.scala create mode 100644 src/test/scala/kernel/alu/AttnScoresTest.scala create mode 100644 src/test/scala/kernel/alu/OutValueTest.scala create mode 100644 src/test/scala/kernel/alu/utils.scala diff --git a/src/main/scala/kernel/alu/AttnScores.scala b/src/main/scala/kernel/alu/AttnScores.scala new file mode 100644 index 0000000..824851d --- /dev/null +++ b/src/main/scala/kernel/alu/AttnScores.scala @@ -0,0 +1,400 @@ +package kernel.alu + +import chisel3._ +import chisel3.util._ +import kernel.alu.GEMMDataType +import kernel.alu.DataWidthConfig +import kernel.utils.DebugLog +import kernel.deprecated.PE + +// QKGenWithReg: use two GEMMFMATotal to get Q and K +// whthin use Reg to store Q and K +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: Query: m * n +// output: Key: m * n +class QKGenWithReg( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + io.Key.valid := false.B + io.Key.bits := DontCare + io.Query.valid := false.B + io.Query.bits := DontCare + + val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + + val Qreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + val Kreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + + qGen.io.matrixA.bits := io.inputToken.bits + qGen.io.matrixA.valid := io.inputToken.valid + qGen.io.matrixB.bits := io.weightQ.bits + qGen.io.matrixB.valid := io.weightQ.valid + qGen.io.results.ready := false.B + + kGen.io.matrixA.bits := io.inputToken.bits + kGen.io.matrixA.valid := io.inputToken.valid + kGen.io.matrixB.bits := io.weightK.bits + kGen.io.matrixB.valid := io.weightK.valid + kGen.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, gen, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + qGen.io.results.ready := true.B + kGen.io.results.ready := true.B + when(qGen.io.results.valid && kGen.io.results.valid) { + Qreg := qGen.io.results.bits + Kreg := kGen.io.results.bits + stateReg := state.done + } + } + is(state.done) { + qGen.io.results.ready := false.B + kGen.io.results.ready := false.B + readyReg := true.B + io.Query.valid := true.B + io.Key.valid := true.B + io.Query.bits := Qreg + io.Key.bits := Kreg + stateReg := state.idle + } + } +} + +// QKGen: use two GEMMFMATotal to get Q and K +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: Query: m * n +// output: Key: m * n +class QKGen( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + io.Key.valid := false.B + io.Key.bits := DontCare + io.Query.valid := false.B + io.Query.bits := DontCare + + val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) + + qGen.io.matrixA.bits := io.inputToken.bits + qGen.io.matrixA.valid := io.inputToken.valid + qGen.io.matrixB.bits := io.weightQ.bits + qGen.io.matrixB.valid := io.weightQ.valid + qGen.io.results.ready := false.B + + kGen.io.matrixA.bits := io.inputToken.bits + kGen.io.matrixA.valid := io.inputToken.valid + kGen.io.matrixB.bits := io.weightK.bits + kGen.io.matrixB.valid := io.weightK.valid + kGen.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, gen, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + qGen.io.results.ready := true.B + kGen.io.results.ready := true.B + when(qGen.io.results.valid && kGen.io.results.valid) { + stateReg := state.done + } + } + is(state.done) { + qGen.io.results.ready := false.B + kGen.io.results.ready := false.B + readyReg := true.B + io.Query.valid := true.B + io.Key.valid := true.B + io.Query.bits := qGen.io.results.bits + io.Key.bits := kGen.io.results.bits + stateReg := state.idle + } + } +} + +// QKMulWithReg: use GEMMFMATotal to get scores +// input: Query: m * n +// input: Key: m * n +// output: scores: m * m +class QKMulWithReg( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Query.valid && io.Key.valid + + val readyReg = RegInit(true.B) + io.Query.ready := readyReg + io.Key.ready := readyReg + io.scores.valid := false.B + io.scores.bits := DontCare + + val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) + val scoresReg = Reg(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.mul + } + } + is(state.mul) { + QK_TMul.io.results.ready := true.B + when(QK_TMul.io.results.valid) { + scoresReg := QK_TMul.io.results.bits + stateReg := state.done + } + } + is(state.done) { + QK_TMul.io.results.ready := false.B + readyReg := true.B + io.scores.valid := true.B + io.scores.bits := scoresReg + stateReg := state.idle + } + } +} + +// QKMul: use GEMMFMATotal to get scores +// input: Query: m * n +// input: Key: m * n +// output: scores: m * m +class QKMul( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Query.valid && io.Key.valid + + val readyReg = RegInit(true.B) + io.Query.ready := readyReg + io.Key.ready := readyReg + io.scores.valid := false.B + io.scores.bits := DontCare + + val doneReg = RegInit(false.B) + + val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) + + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.mul + } + } + is(state.mul) { + when(QK_TMul.io.results.valid) { + stateReg := state.done + } + } + is(state.done) { + readyReg := true.B + io.scores.valid := true.B + io.scores.bits := QK_TMul.io.results.bits + stateReg := state.idle + } + } +} + +// AttnScores: use QKGen to get Q and K, then use QKMul to get scores +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: scores: m * m +class AttnScores( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + + io.scores.valid := false.B + io.scores.bits := DontCare + + // val scoresReg = Reg(Vec(m, Vec(m, UInt(config.outputWidth.W)))) + // val QKGen = Module(new QKGenWithReg(m, k, n, peCount, gemmType)) + val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) + + QKGen.io.inputToken.valid := io.inputToken.valid + QKGen.io.inputToken.bits := io.inputToken.bits + QKGen.io.weightQ.valid := io.weightQ.valid + QKGen.io.weightQ.bits := io.weightQ.bits + QKGen.io.weightK.valid := io.weightK.valid + QKGen.io.weightK.bits := io.weightK.bits + + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + + // val QKMul = Module(new QKMulWithReg(m, n, peCount, gemmType)) + val QKMul = Module(new QKMul(m, n, peCount, gemmType)) + + QKMul.io.Query.valid := QKGen.io.Query.valid + QKMul.io.Query.bits := QKGen.io.Query.bits + QKMul.io.Key.valid := QKGen.io.Key.valid + QKMul.io.Key.bits := QKGen.io.Key.bits + QKMul.io.scores.ready := false.B + + object state extends ChiselEnum { + val idle, gen, mul, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + QKGen.io.Query.ready := true.B + QKGen.io.Key.ready := true.B + when(QKGen.io.Query.valid && QKGen.io.Key.valid) { + stateReg := state.mul + } + } + is(state.mul) { + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + QKMul.io.scores.ready := true.B + when(QKMul.io.scores.valid) { + // scoresReg := QKMul.io.scores.bits + stateReg := state.done + } + } + is(state.done) { + QKMul.io.scores.ready := false.B + readyReg := true.B + io.scores.valid := true.B + // io.scores.bits := scoresReg + io.scores.bits := QKMul.io.scores.bits + stateReg := state.idle + } + } +} diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala index be78cd6..55ac1f3 100644 --- a/src/main/scala/kernel/alu/GemmFMA.scala +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -186,8 +186,8 @@ class GEMMFMATotal( }) }) //k * peCount size block of matrixB - multiFMA.io.reset := false.B - multiFMA.io.blockResult.ready := true.B + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B val resultsReg = Reg(Vec(m, Vec(n, UInt(config.outputWidth.W)))) @@ -205,28 +205,19 @@ class GEMMFMATotal( } is(state.compute) { - // printf(p"rowIndex: ${rowIndex.value}, colIndex: ${colIndex.value}\n") - // printf(p"matrixA_row: ${multiFMA.io.matrixA_row.bits}\n") - // for (i <- 0 until peCount) { - // printf(p"matrixB_cols(${i}): ${multiFMA.io.matrixB_cols.bits(i)}\n") - // } multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B when(multiFMA.io.blockResult.valid) { for (i <- 0 until peCount) { resultsReg(rowIndex.value)((colIndex.value * peCount.U + i.U) % n.U) := multiFMA.io.blockResult.bits(i) } - // for (i <- 0 until m) { - // for (j <- 0 until n) { - // printf(p"i:${i},j:${j},${resultsReg(i)(j)}\t") - // } - // printf(p"\n") - // } stateReg := state.update } } is(state.update) { multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B when(colIndex.inc()) { when(rowIndex.inc()) { stateReg := state.done @@ -384,806 +375,3 @@ class GEMMSingleQueue( io.done := gemm.io.done } - -// first use GEMMFMATotal to get Q and K, then use GEMMFMASingle to get Q*K^T -// out one row of score matrix -class AttnScoresSingle( - val m: Int, - val k: Int, - val n: Int, - val peCount1: Int = 16, - val peCount2: Int = 16, - val gemmType: GEMMDataType.Type, - val bufferSizeGemm: Int = 32 -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) - val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val scoreRow = Decoupled(new currentRowIndex(m, m)) - val resetBuffer = Input(Bool()) - val done = Output(Bool()) - }) - - val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg - io.scoreRow.valid := false.B - io.scoreRow.bits := DontCare - io.done := false.B - - //use GEMMFMATotal to get Q and K - val qGen = Module(new GEMMFMATotal(m, k, n, peCount1, gemmType)) - val kGen = Module(new GEMMFMATotal(m, k, n, peCount1, gemmType)) - qGen.io.matrixA <> io.inputToken - qGen.io.matrixB <> io.weightQ - kGen.io.matrixA <> io.inputToken - kGen.io.matrixB <> io.weightK - - // when qGen and kGen are done, use GEMMFMASingle to get Q*K^T - // Q: m * n, K: m * n -> Q*K^T: m * m - val QK_TMul = Module(new GEMMSingleQueue(m, n, m, peCount2, gemmType, bufferSizeGemm)) - QK_TMul.io.matrixA <> qGen.io.results - - val K_T = VecInit(Seq.fill(n)(VecInit(Seq.fill(m)(0.U(config.inputWidth.W))))) - for (i <- 0 until k) { - for (j <- 0 until n) { - K_T(i)(j) := kGen.io.results.bits(j)(i) - } - } - - QK_TMul.io.matrixB.valid := kGen.io.results.valid - // QK_TMul.io.matrixB.bits := VecInit(kGen.io.results.bits.transpose.map(VecInit(_))) - QK_TMul.io.matrixB.bits := K_T - kGen.io.results.ready := QK_TMul.io.matrixB.ready - - QK_TMul.io.flush := io.resetBuffer - io.scoreRow <> QK_TMul.io.currentRow - - object state extends ChiselEnum { - val idle, gen, mul, collect, done = Value - } - val stateReg = RegInit(state.idle) - - switch(stateReg) { - is(state.idle) { - when(dataValid) { - readyReg := false.B - stateReg := state.gen - } - } - is(state.gen) { - when(qGen.io.results.valid && kGen.io.results.valid) { - debugLog(p"qGen results: ${qGen.io.results.bits}\n") - debugLog(p"kGen results: ${kGen.io.results.bits}\n") - stateReg := state.mul - } - } - is(state.mul) { - when(QK_TMul.io.currentRow.valid) { - stateReg := state.collect - } - } - is(state.collect) { - when(QK_TMul.io.done) { - stateReg := state.done - } - } - is(state.done) { - io.done := true.B - readyReg := true.B - stateReg := state.idle - } - } - -} - -class AttnScoresTotal( - val m: Int, - val k: Int, - val n: Int, - val peCount: Int, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) - val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) - }) - - val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg - io.scores.valid := false.B - io.scores.bits := DontCare - - //use GEMMFMATotal to get Q and K - val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) - val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) - qGen.io.matrixA := DontCare - qGen.io.matrixB := DontCare - kGen.io.matrixA := DontCare - kGen.io.matrixB := DontCare - qGen.io.results.ready := false.B - kGen.io.results.ready := false.B - - // val Qreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - // val Kreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - - // when qGen and kGen are done, use GEMMFMASingle to get Q*K^T - // Q: m * n, K: m * n -> Q*K^T: m * m - val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) - QK_TMul.io.matrixA := DontCare - QK_TMul.io.matrixB := DontCare - QK_TMul.io.results.ready := false.B - - // object state extends ChiselEnum { - // val idle, gen, mul, done = Value - // } - // val stateReg = RegInit(state.idle) - - // switch(stateReg) { - // is(state.idle) { - // when(dataValid) { - // readyReg := false.B - // stateReg := state.gen - // } - // } - // is(state.gen) { - // when(qGen.io.results.valid && kGen.io.results.valid) { - // // Qreg := qGen.io.results.bits - // // Kreg := kGen.io.results.bits - // stateReg := state.mul - // } - // } - // is(state.mul) { - // when(QK_TMul.io.results.valid) { - // stateReg := state.done - // } - // } - // is(state.done) { - // readyReg := true.B - // stateReg := state.idle - // } - // } - - when(dataValid) { - readyReg := false.B - qGen.io.matrixA <> io.inputToken - qGen.io.matrixB <> io.weightQ - kGen.io.matrixA <> io.inputToken - kGen.io.matrixB <> io.weightK - qGen.io.results.ready := true.B - kGen.io.results.ready := true.B - when(qGen.io.results.valid && kGen.io.results.valid) { - QK_TMul.io.matrixA.valid := qGen.io.results.valid - QK_TMul.io.matrixA.bits := qGen.io.results.bits - qGen.io.results.ready := true.B - QK_TMul.io.matrixB.valid := kGen.io.results.valid - QK_TMul.io.matrixB.bits := VecInit(kGen.io.results.bits.transpose.map(VecInit(_))) - kGen.io.results.ready := true.B - when(QK_TMul.io.results.valid) { - io.scores.valid := QK_TMul.io.results.valid - io.scores.bits := QK_TMul.io.results.bits - readyReg := true.B - } - } - } -} - -// QKGenWithReg: use two GEMMFMATotal to get Q and K -// whthin use Reg to store Q and K -// input: inputToken: m * k -// input: weightQ: k * n -// input: weightK: k * n -// output: Query: m * n -// output: Key: m * n -class QKGenWithReg( - val m: Int, - val k: Int, - val n: Int, - val peCount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) - val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) - val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) - }) - - val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg - io.Key.valid := false.B - io.Key.bits := DontCare - io.Query.valid := false.B - io.Query.bits := DontCare - - val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) - val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) - - val Qreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - val Kreg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - - qGen.io.matrixA.bits := io.inputToken.bits - qGen.io.matrixA.valid := io.inputToken.valid - qGen.io.matrixB.bits := io.weightQ.bits - qGen.io.matrixB.valid := io.weightQ.valid - qGen.io.results.ready := true.B - - kGen.io.matrixA.bits := io.inputToken.bits - kGen.io.matrixA.valid := io.inputToken.valid - kGen.io.matrixB.bits := io.weightK.bits - kGen.io.matrixB.valid := io.weightK.valid - kGen.io.results.ready := true.B - - io.Query.valid := false.B - io.Query.bits := Qreg - io.Key.valid := false.B - io.Key.bits := Kreg - - // qGen.io.matrixA <> io.inputToken - // qGen.io.matrixB <> io.weightQ - // kGen.io.matrixA <> io.inputToken - // kGen.io.matrixB <> io.weightK - - // io.Query <> qGen.io.results - // io.Key <> kGen.io.results - - object state extends ChiselEnum { - val idle, gen, done = Value - } - val stateReg = RegInit(state.idle) - - switch(stateReg) { - is(state.idle) { - when(dataValid) { - readyReg := false.B - stateReg := state.gen - } - } - is(state.gen) { - when(qGen.io.results.valid && kGen.io.results.valid) { - Qreg := qGen.io.results.bits - Kreg := kGen.io.results.bits - stateReg := state.done - } - } - is(state.done) { - readyReg := true.B - io.Query.valid := true.B - io.Key.valid := true.B - stateReg := state.idle - } - } -} - -// QKGen: use two GEMMFMATotal to get Q and K -// input: inputToken: m * k -// input: weightQ: k * n -// input: weightK: k * n -// output: Query: m * n -// output: Key: m * n -class QKGen( - val m: Int, - val k: Int, - val n: Int, - val peCount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) - val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val Query = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) - val Key = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) - }) - - val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg - io.Key.valid := false.B - io.Key.bits := DontCare - io.Query.valid := false.B - io.Query.bits := DontCare - - val qGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) - val kGen = Module(new GEMMFMATotal(m, k, n, peCount, gemmType)) - - // qGen.io.matrixA.bits := io.inputToken.bits - // qGen.io.matrixA.valid := io.inputToken.valid - // qGen.io.matrixB.bits := io.weightQ.bits - // qGen.io.matrixB.valid := io.weightQ.valid - // qGen.io.results.ready := true.B - - // kGen.io.matrixA.bits := io.inputToken.bits - // kGen.io.matrixA.valid := io.inputToken.valid - // kGen.io.matrixB.bits := io.weightK.bits - // kGen.io.matrixB.valid := io.weightK.valid - // kGen.io.results.ready := true.B - - // io.Query.valid := qGen.io.results.valid - // io.Query.bits := qGen.io.results.bits - // io.Key.valid := kGen.io.results.valid - // io.Key.bits := kGen.io.results.bits - - qGen.io.matrixA <> io.inputToken - qGen.io.matrixB <> io.weightQ - kGen.io.matrixA <> io.inputToken - kGen.io.matrixB <> io.weightK - - io.Query <> qGen.io.results - io.Key <> kGen.io.results - - object state extends ChiselEnum { - val idle, gen, done = Value - } - val stateReg = RegInit(state.idle) - - switch(stateReg) { - is(state.idle) { - when(dataValid) { - readyReg := false.B - stateReg := state.gen - } - } - is(state.gen) { - when(qGen.io.results.valid && kGen.io.results.valid) { - stateReg := state.done - } - } - is(state.done) { - readyReg := true.B - io.Query.valid := true.B - io.Key.valid := true.B - stateReg := state.idle - } - } -} - -// QKMulTotalWithReg: use GEMMFMATotal to get scores -// input: Query: m * n -// input: Key: m * n -// output: scores: m * m -class QKMulTotalWithReg( - val m: Int, - val n: Int, - val peCount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) - val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) - val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) - }) - - val dataValid = io.Query.valid && io.Key.valid - - val readyReg = RegInit(true.B) - io.Query.ready := readyReg - io.Key.ready := readyReg - io.scores.valid := false.B - io.scores.bits := DontCare - - val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) - val scoresReg = Reg(Vec(m, Vec(m, UInt(config.outputWidth.W)))) - - QK_TMul.io.matrixA.valid := io.Query.valid - QK_TMul.io.matrixA.bits := io.Query.bits - QK_TMul.io.matrixB.valid := io.Key.valid - QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) - QK_TMul.io.results.ready := true.B - - // io.scores.valid := QK_TMul.io.results.valid - // io.scores.bits := QK_TMul.io.results.bits - io.scores.valid := false.B - io.scores.bits := scoresReg - - object state extends ChiselEnum { - val idle, mul, done = Value - } - val stateReg = RegInit(state.idle) - - switch(stateReg) { - is(state.idle) { - when(dataValid) { - readyReg := false.B - stateReg := state.mul - } - } - is(state.mul) { - when(QK_TMul.io.results.valid) { - // for (i <- 0 until m) { - // for (j <- 0 until n) { - // // printf(p"QK_TMul.io.matrixA.bits($i)($j): ${QK_TMul.io.matrixA.bits(i)(j)}\n") - // // printf(p"QK_TMul.io.matrixB.bits($i)($j): ${QK_TMul.io.matrixB.bits(i)(j)}\n") - // printf(p"io.Query.bits($i)($j): ${io.Query.bits(i)(j)}\n") - // printf(p"io.Key.bits($i)($j): ${io.Key.bits(i)(j)}\n") - // } - // } - // for (i <- 0 until m) { - // for (j <- 0 until m) { - // printf(p"QK_TMul.io.results.bits($i)($j): ${QK_TMul.io.results.bits(i)(j)}\n") - // } - // } - scoresReg := QK_TMul.io.results.bits - stateReg := state.done - } - } - is(state.done) { - readyReg := true.B - io.scores.valid := true.B - stateReg := state.idle - } - } -} - -// QKMulTotal: use GEMMFMATotal to get scores -// input: Query: m * n -// input: Key: m * n -// output: scores: m * m -class QKMulTotal( - val m: Int, - val n: Int, - val peCount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) - val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) - val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) - }) - - val dataValid = io.Query.valid && io.Key.valid - - val readyReg = RegInit(true.B) - io.Query.ready := readyReg - io.Key.ready := readyReg - io.scores.valid := false.B - io.scores.bits := DontCare - - val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) - - QK_TMul.io.matrixA.valid := io.Query.valid - QK_TMul.io.matrixA.bits := io.Query.bits - QK_TMul.io.matrixB.valid := io.Key.valid - QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) - QK_TMul.io.results.ready := true.B - - io.scores.valid := QK_TMul.io.results.valid - io.scores.bits := QK_TMul.io.results.bits - - object state extends ChiselEnum { - val idle, mul, done = Value - } - val stateReg = RegInit(state.idle) - - switch(stateReg) { - is(state.idle) { - when(dataValid) { - readyReg := false.B - stateReg := state.mul - } - } - is(state.mul) { - when(QK_TMul.io.results.valid) { - // for (i <- 0 until m) { - // for (j <- 0 until n) { - // // printf(p"QK_TMul.io.matrixA.bits($i)($j): ${QK_TMul.io.matrixA.bits(i)(j)}\n") - // // printf(p"QK_TMul.io.matrixB.bits($i)($j): ${QK_TMul.io.matrixB.bits(i)(j)}\n") - // printf(p"io.Query.bits($i)($j): ${io.Query.bits(i)(j)}\n") - // printf(p"io.Key.bits($i)($j): ${io.Key.bits(i)(j)}\n") - // } - // } - // for (i <- 0 until m) { - // for (j <- 0 until m) { - // printf(p"QK_TMul.io.results.bits($i)($j): ${QK_TMul.io.results.bits(i)(j)}\n") - // } - // } - stateReg := state.done - } - } - is(state.done) { - readyReg := true.B - io.scores.valid := true.B - stateReg := state.idle - } - } -} - -// AttnScores: use QKGen to get Q and K, then use QKMulTotal to get scores -// input: inputToken: m * k -// input: weightQ: k * n -// input: weightK: k * n -// output: scores: m * m -class AttnScores( - val m: Int, - val k: Int, - val n: Int, - val peCount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) - val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val scores = Decoupled(Vec(m, Vec(m, UInt(config.outputWidth.W)))) - }) - - val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg - - io.scores.valid := false.B - io.scores.bits := DontCare - - val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) - - QKGen.io.inputToken.valid := io.inputToken.valid - QKGen.io.inputToken.bits := io.inputToken.bits - QKGen.io.weightQ.valid := io.weightQ.valid - QKGen.io.weightQ.bits := io.weightQ.bits - QKGen.io.weightK.valid := io.weightK.valid - QKGen.io.weightK.bits := io.weightK.bits - QKGen.io.Query.ready := true.B - QKGen.io.Key.ready := true.B - - // val QueryReg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - // val KeyReg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) - - val QKMul = Module(new QKMulTotalWithReg(m, n, peCount, gemmType)) - - val scoresReg = Reg(Vec(m, Vec(m, UInt(config.inputWidth.W)))) - - QKMul.io.Query.valid := QKGen.io.Query.valid - QKMul.io.Query.bits := QKGen.io.Query.bits - // QKMul.io.Query.bits := QueryReg - QKMul.io.Key.valid := QKGen.io.Key.valid - QKMul.io.Key.bits := QKGen.io.Key.bits - // QKMul.io.Key.bits := KeyReg - QKMul.io.scores.ready := true.B - - // io.scores.valid := QKMul.io.scores.valid - // io.scores.bits := QKMul.io.scores.bits - io.scores.valid := false.B - io.scores.bits := scoresReg - - object state extends ChiselEnum { - val idle, gen, mul, done = Value - } - val stateReg = RegInit(state.idle) - - switch(stateReg) { - is(state.idle) { - when(dataValid) { - readyReg := false.B - stateReg := state.gen - } - } - is(state.gen) { - printf(p"gen:\n") - when(QKGen.io.Query.valid && QKGen.io.Key.valid) { - // QueryReg := QKGen.io.Query.bits - // KeyReg := QKGen.io.Key.bits - // for (i <- 0 until m) { - // for (j <- 0 until n) { - // // printf(p"QueryReg($i)($j): ${QueryReg(i)(j)}\n") - // // printf(p"KeyReg($i)($j): ${KeyReg(i)(j)}\n") - // printf(p"QKGen.io.Query.bits($i)($j): ${QKGen.io.Query.bits(i)(j)}\n") - // printf(p"QKGen.io.Key.bits($i)($j): ${QKGen.io.Key.bits(i)(j)}\n") - // } - // } - stateReg := state.mul - } - } - is(state.mul) { - printf(p"mul:\n") - when(QKMul.io.scores.valid) { - // for (i <- 0 until m) { - // for (j <- 0 until n) { - // printf(p"QKMul.io.Query.bits($i)($j): ${QKMul.io.Query.bits(i)(j)}\n") - // printf(p"QKMul.io.Key.bits($i)($j): ${QKMul.io.Key.bits(i)(j)}\n") - // } - // } - // for (i <- 0 until m) { - // for (j <- 0 until m) { - // printf(p"QKMul.io.scores.bits($i)($j): ${QKMul.io.scores.bits(i)(j)}\n") - // } - // } - scoresReg := QKMul.io.scores.bits - stateReg := state.done - } - } - is(state.done) { - readyReg := true.B - io.scores.valid := true.B - stateReg := state.idle - } - } - -} - -// OutValue: get the final output value -// input: AttnWeights: m * m -// input: Value: m * n -// output: AttnOut: m * n -class OutValue( - val m: Int, - val n: Int, - val peCount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val AttnWeights = Flipped(Decoupled(Vec(m, Vec(m, UInt(config.inputWidth.W))))) - val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) - val AttnOut = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) - }) - - val dataValid = io.AttnWeights.valid && io.Value.valid - - val readyReg = RegInit(true.B) - io.AttnWeights.ready := readyReg - io.Value.ready := readyReg - io.AttnOut.valid := false.B - io.AttnOut.bits := DontCare - - val ValueMul = Module(new GEMMFMATotal(m, m, n, peCount, gemmType)) - - ValueMul.io.matrixA.valid := io.AttnWeights.valid - ValueMul.io.matrixA.bits := io.AttnWeights.bits - ValueMul.io.matrixB.valid := io.Value.valid - ValueMul.io.matrixB.bits := io.Value.bits - ValueMul.io.results.ready := true.B - - io.AttnOut.valid := ValueMul.io.results.valid - io.AttnOut.bits := ValueMul.io.results.bits - -} - -// OutValue: get the final output value -// input: one row of AttnWeights: 1 * m ,total m rows -// input: Value: m * n -// output: one row of AttnOut: 1 * n ,total m rows -// output: done: Bool -class OutValueSingle( - val m: Int, - val n: Int, - val peCount: Int = 16, - val gemmType: GEMMDataType.Type -)( - implicit config: DataWidthConfig) - extends Module - with DebugLog { - val io = IO(new Bundle { - val currentAttnW = Flipped(Decoupled(new currentRowIndex(m, m))) - val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) - val currentAttnO = Decoupled(new currentRowIndex(m, n)) - val done = Output(Bool()) - }) - - val dataValid = io.currentAttnW.valid && io.Value.valid - - val readyReg = RegInit(true.B) - io.currentAttnW.ready := readyReg - io.Value.ready := readyReg - io.currentAttnO.valid := false.B - io.currentAttnO.bits := DontCare - io.done := false.B - - val multiFMA = Module(new MultiFMA(m, peCount, gemmType)) - - val rowIndex = Counter(m) - val colIndex = Counter(n / peCount) - - multiFMA.io.matrixA_row.valid := io.currentAttnW.valid - multiFMA.io.matrixA_row.bits := io.currentAttnW.bits.value - - multiFMA.io.matrixB_cols.valid := io.Value.valid - multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(m) { j => - VecInit(Seq.tabulate(peCount) { i => - io.Value.bits(j)((colIndex.value * peCount.U + i.U) % n.U) - }) - }) //m * peCount size block of Value - - multiFMA.io.reset := false.B - multiFMA.io.blockResult.ready := true.B - - val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) - - object state extends ChiselEnum { - val idle, compute, update, done = Value - } - - val stateReg = RegInit(state.idle) - - switch(stateReg) { - is(state.idle) { - when(dataValid) { - // readyReg := false.B - stateReg := state.compute - } - } - - is(state.compute) { - multiFMA.io.reset := false.B - // printf(p"multiFMA.io.matrixA_row.bits: ${multiFMA.io.matrixA_row.bits}\n") - // printf(p"multiFMA.io.matrixB_cols.bits: ${multiFMA.io.matrixB_cols.bits}\n") - - when(multiFMA.io.blockResult.valid) { - for (i <- 0 until peCount) { - currentRowReg(colIndex.value * peCount.U + i.U) := multiFMA.io.blockResult.bits(i) - } - stateReg := state.update - } - } - - is(state.update) { - multiFMA.io.reset := true.B - io.currentAttnO.valid := false.B - when(colIndex.inc()) { - io.currentAttnO.valid := true.B - io.currentAttnO.bits.index := rowIndex.value - io.currentAttnO.bits.value := currentRowReg - // readyReg := true.B - // io.currentAttnO.ready := true.B - // io.Value.ready := true.B - when(rowIndex.inc()) { - stateReg := state.done - }.otherwise { - stateReg := state.compute - // wait for next row of AttnWeights - } - }.otherwise { - stateReg := state.compute - } - - } - is(state.done) { - io.done := true.B - readyReg := true.B - stateReg := state.idle - } - } -} diff --git a/src/main/scala/kernel/alu/OutValue.scala b/src/main/scala/kernel/alu/OutValue.scala new file mode 100644 index 0000000..36547ab --- /dev/null +++ b/src/main/scala/kernel/alu/OutValue.scala @@ -0,0 +1,177 @@ +package kernel.alu + +import chisel3._ +import chisel3.util._ +import kernel.alu.GEMMDataType +import kernel.alu.DataWidthConfig +import kernel.utils.DebugLog +import kernel.deprecated.PE + +// OutValue: get the final output value +// input: AttnWeights: m * m +// input: Value: m * n +// output: AttnOut: m * n +class OutValue( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Scores = Flipped(Decoupled(Vec(m, Vec(m, UInt(config.inputWidth.W))))) + val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val AttnOut = Decoupled(Vec(m, Vec(n, UInt(config.outputWidth.W)))) + }) + + val dataValid = io.Scores.valid && io.Value.valid + + val readyReg = RegInit(true.B) + io.Scores.ready := readyReg + io.Value.ready := readyReg + io.AttnOut.valid := false.B + io.AttnOut.bits := DontCare + + val ValueMul = Module(new GEMMFMATotal(m, m, n, peCount, gemmType)) + + ValueMul.io.matrixA.valid := io.Scores.valid + ValueMul.io.matrixA.bits := io.Scores.bits + ValueMul.io.matrixB.valid := io.Value.valid + ValueMul.io.matrixB.bits := io.Value.bits + ValueMul.io.results.ready := false.B + + object state extends ChiselEnum { + val idle, compute, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.compute + } + } + is(state.compute) { + ValueMul.io.results.ready := true.B + when(ValueMul.io.results.valid) { + stateReg := state.done + } + } + is(state.done) { + ValueMul.io.results.ready := false.B + readyReg := true.B + io.AttnOut.valid := true.B + io.AttnOut.bits := ValueMul.io.results.bits + stateReg := state.idle + } + } +} + +// OutValue: get the final output value +// input: one row of AttnWeights: 1 * m ,total m rows +// input: Value: m * n +// output: one row of AttnOut: 1 * n ,total m rows +// output: done: Bool +class OutValueSingle( + val m: Int, + val n: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val currentScores = Flipped(Decoupled(new currentRowIndex(m, m))) + val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val currentAttnOut = Decoupled(new currentRowIndex(m, n)) + val done = Output(Bool()) + }) + + val dataValid = io.currentScores.valid && io.Value.valid + + io.currentScores.ready := true.B + io.Value.ready := true.B + io.currentAttnOut.valid := false.B + io.currentAttnOut.bits := DontCare + io.done := false.B + + val ValueReg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) + ValueReg := io.Value.bits + + val multiFMA = Module(new MultiFMA(m, peCount, gemmType)) + + val rowIndex = Counter(m) + val colIndex = Counter(n / peCount) + + multiFMA.io.matrixA_row.valid := io.currentScores.valid + multiFMA.io.matrixA_row.bits := io.currentScores.bits.value + + multiFMA.io.matrixB_cols.valid := io.Value.valid + multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(m) { j => + VecInit(Seq.tabulate(peCount) { i => + ValueReg(j)(((colIndex.value << log2Ceil(peCount).U) + i.U)(log2Ceil(n)-1, 0)) + }) + }) //m * peCount size block of Value + + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + + val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) + + object state extends ChiselEnum { + val idle, compute, update, load, done = Value + } + + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + io.Value.ready := false.B + stateReg := state.compute + } + } + is(state.compute) { + io.currentScores.ready := false.B + multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B + when(multiFMA.io.blockResult.valid) { + for (i <- 0 until peCount) { + currentRowReg(colIndex.value * peCount.U + i.U) := multiFMA.io.blockResult.bits(i) + } + stateReg := state.update + } + } + is(state.update) { + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B + io.currentAttnOut.valid := false.B + when(colIndex.inc()) { + io.currentAttnOut.valid := true.B + io.currentAttnOut.bits.index := rowIndex.value + io.currentAttnOut.bits.value := currentRowReg + when(rowIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.load + } + }.otherwise { + stateReg := state.compute + } + } + is(state.load) { + io.currentScores.ready := true.B + stateReg := state.compute + } + is(state.done) { + io.done := true.B + io.Value.ready := true.B + io.currentScores.ready := true.B + stateReg := state.idle + } + } +} diff --git a/src/test/scala/kernel/alu/AttnScoresTest.scala b/src/test/scala/kernel/alu/AttnScoresTest.scala new file mode 100644 index 0000000..8f0556a --- /dev/null +++ b/src/test/scala/kernel/alu/AttnScoresTest.scala @@ -0,0 +1,281 @@ +package kernel.alu + +import chisel3._ +import chiseltest._ +import org.scalatest.freespec.AnyFreeSpec +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.ParallelTestExecution +import scala.reflect.ClassTag +import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} +import ujson.Arr +import Utils._ + +class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { + + private def testQKGen[T: Numeric: ClassTag]( + dut: QKGen + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val Query = mmul(inputToken, weightQ) + val Key = mmul(inputToken, weightK) + + printmat(Query) + printmat(Key) + + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + while (!(dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean())) { + dut.clock.step() + } + + dut.io.Key.ready.poke(true.B) + dut.io.Query.ready.poke(true.B) + + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until n + } { + val outBigInt = dut.io.Query.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = Query(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testQKMul[T: Numeric: ClassTag]( + dut: QKMul + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val Query = matInit[T](m, n) + val Key = matInit[T](m, n) + val expectedResults = mmul(Query, Key.transpose) + + println("Query:") + printmat(Query) + println("Key:") + printmat(Key) + println("expectedResults:") + printmat(expectedResults) + + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) + } + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + + while (!dut.io.scores.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testAttnScores[T: Numeric: ClassTag]( + dut: AttnScores + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val Query = mmul(inputToken, weightQ) + val Key = mmul(inputToken, weightK) + val expectedResults = mmul(Query, Key.transpose) + + print("Query:\n") + printmat(Query) + print("Key:\n") + printmat(Key) + print("expectedResults:\n") + printmat(expectedResults) + + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + while (!dut.io.scores.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + // "AttnScoresTotal " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new AttnScoresTotal(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScoresTotal[Int](dut) + // } + // } + + // "QKMul " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMul(m = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMul[Int](dut) + // } + // } +// "AttnScores " should "compute fxp matrix multiplication" in { +// implicit val config: DataWidthConfig = FxpConfig +// test(new AttnScores(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) +// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => +// testAttnScores[Int](dut) +// } +// } + + // "AttnScores " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new AttnScores(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScores[Float](dut) + // } + // } + + // "QKMul " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMul(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMul[Int](dut) + // } + // } + +// "QKMulWithReg " should "compute fxp matrix multiplication" in { +// implicit val config: DataWidthConfig = FxpConfig +// test(new QKMulWithReg(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) +// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => +// testQKMulWithReg[Int](dut) +// } +// } + + "QKGen " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new QKGen(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testQKGen[Int](dut) + } + } + +// "QKGenWithReg " should "compute fxp matrix multiplication" in { +// implicit val config: DataWidthConfig = FxpConfig +// test(new QKGenWithReg(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) +// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => +// testQKGenWithReg[Int](dut) +// } +// } +} diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala index 87259d9..2700057 100644 --- a/src/test/scala/kernel/alu/GemmFMATest.scala +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -7,187 +7,11 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.ParallelTestExecution import scala.reflect.ClassTag import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} +import ujson.Arr +import Utils._ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution { - def mmul[T: Numeric: ClassTag](a: Array[Array[T]], b: Array[Array[T]]): Array[Array[T]] = { - val rows = a.length - val cols = b(0).length - val n = b.length - val num = implicitly[Numeric[T]] - - Array.tabulate(rows, cols) { (i, j) => - var sum = num.zero - for (k <- 0 until n) { - sum = num.plus(sum, num.times(a(i)(k), b(k)(j))) - } - sum - } - } - - def matInit[T: Numeric: ClassTag]( - rows: Int, - cols: Int - )( - implicit config: DataWidthConfig - ): Array[Array[T]] = { - val r = new scala.util.Random(42) - val ct = implicitly[ClassTag[T]] - val numeric = implicitly[Numeric[T]] - - ct.runtimeClass match { - case c if c == classOf[Int] => - // 定点数使用 -8 到 7 的整数 - Array.fill(rows, cols)( - numeric.fromInt( - // r.nextInt(math.pow(2, config.inputWidth).toInt) - math.pow(2, config.inputWidth - 1).toInt - r.nextInt(4) - 2 - ) - ) - case c if c == classOf[Float] => - // 32位浮点数使用 -1 到 1 的随机浮点数 - // Float 类型 - Array.fill(rows, cols)((r.nextFloat() * 2 - 1).asInstanceOf[T]) - case c if c == classOf[Double] => - // 64位浮点数使用 -1 到 1 的随机浮点数 - Array.fill(rows, cols)((r.nextDouble() * 2 - 1).asInstanceOf[T]) - case _ => - throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") - } - } - - def toSignedBigInt(value: BigInt, width: Int): BigInt = { - val signBit = (value >> (width - 1)) & 1 - - if (signBit == 1) { - val maxValue = BigInt(1) << width - value - maxValue - } else { - value - } - } - - def printmat[T: Numeric: ClassTag](m: Array[Array[T]]): Unit = { - val numeric = implicitly[Numeric[T]] - val ct = implicitly[ClassTag[T]] - - m.foreach { r => - r.foreach { v => - ct.runtimeClass match { - case c if c == classOf[Float] => - print(f"${v.asInstanceOf[Float]}%.4f\t") - case c if c == classOf[Double] => - print(f"${v.asInstanceOf[Double]}%.4f\t") - case c if c == classOf[Int] => - print(f"${v.asInstanceOf[Int]}%d\t") - case _ => - throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") - } - } - println(";") - } - println() - } - - def printmat[T: Numeric: ClassTag](m: Array[T], x: Int, y: Int)(implicit config: DataWidthConfig): Unit = { - val numeric = implicitly[Numeric[T]] - val ct = implicitly[ClassTag[T]] - - for (i <- 0 until x) { - for (j <- 0 until y) { - ct.runtimeClass match { - case c if c == classOf[Float] => - print(f"${m(i * y + j).asInstanceOf[Float]}%.4f\t") - case c if c == classOf[Double] => - print(f"${m(i * y + j).asInstanceOf[Double]}%.4f\t") - case c if c == classOf[Int] => - print(f"${m(i * y + j).asInstanceOf[Int]}%d\t") - case c if c == classOf[BigInt] => - print(f"${toSignedBigInt(m(i * y + j).asInstanceOf[BigInt], config.inputWidth)}%d\t") - case _ => - throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") - } - } - println(";") - } - println() - } - - // convert T to binary bigInt - def toBinaryBigInt[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): BigInt = { - val ct = implicitly[ClassTag[T]] - val num = implicitly[Numeric[T]] - - ct.runtimeClass match { - case c if c == classOf[Int] => - val intValue = v.asInstanceOf[Int] - // 使用 inputWidth 位来表示所有整数,保持符号位 - val mask = (1L << config.inputWidth) - 1 - BigInt(intValue) & mask - case c if c == classOf[Float] => - BigInt(java.lang.Float.floatToRawIntBits(v.asInstanceOf[Float]).toBinaryString, 2) - case c if c == classOf[Double] => - BigInt(java.lang.Double.doubleToRawLongBits(v.asInstanceOf[Double]).toBinaryString, 2) - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") - } - } - - // convrt T to binary string - private def toBinaryString[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): String = { - val ct = implicitly[ClassTag[T]] - val num = implicitly[Numeric[T]] - - ct.runtimeClass match { - case c if c == classOf[Int] => - val intBValue = v.asInstanceOf[Int].toBinaryString - if (intBValue.length < config.inputWidth) { - intBValue.reverse.padTo(config.inputWidth, '0').reverse - } else { - intBValue.takeRight(config.inputWidth) - } - case c if c == classOf[Float] => - java.lang.Float - .floatToRawIntBits(v.asInstanceOf[Float]) - .toBinaryString - .reverse - .padTo(config.inputWidth, '0') - .reverse - case c if c == classOf[Double] => - java.lang.Double - .doubleToRawLongBits(v.asInstanceOf[Double]) - .toBinaryString - .reverse - .padTo(config.inputWidth, '0') - .reverse - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") - } - } - - // convert binary bigInt to T - def fromBinaryBigInt[T: Numeric: ClassTag](bigInt: BigInt)(implicit config: DataWidthConfig): T = { - val ct = implicitly[ClassTag[T]] - - ct.runtimeClass match { - case c if c == classOf[Int] => - val intValue = bigInt.toInt - // 处理符号位 - val signExtendedValue = if ((intValue & (1 << (config.inputWidth - 1))) != 0) { - intValue | ~((1 << config.inputWidth) - 1) - } else { - intValue - } - signExtendedValue.asInstanceOf[T] - case c if c == classOf[Float] => - java.lang.Float.intBitsToFloat(bigInt.toInt).asInstanceOf[T] - case c if c == classOf[Double] => - java.lang.Double.longBitsToDouble(bigInt.toLong).asInstanceOf[T] - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") - } - } - private def testMultiFMA[T: Numeric: ClassTag]( dut: MultiFMA )( @@ -197,21 +21,10 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val peCount = dut.peCount val gemmType = dut.gemmType - // val fixedMatrix = Array( - // Array(4, 2, 3, 1) - // ) - // val fixedMatrix2 = Array( - // Array(4, 2, 3, 1), - // Array(0, 5, 1, 3), - // Array(4, 2, 1, 0), - // Array(0, 3, 1, 3) - // ) - // val matrixA_row = fixedMatrix - // val matrixB_cols = fixedMatrix2 val matrixA_row = matInit[T](1, k) val matrixB_cols = matInit[T](k, peCount) - val expectedResults = mmul(matrixA_row, matrixB_cols) + printmat(matrixA_row) printmat(matrixB_cols) printmat(expectedResults) @@ -224,11 +37,13 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe println("matrixA_row and matrixB_cols are ready") dut.io.matrixA_row.valid.poke(true.B) dut.io.matrixB_cols.valid.poke(true.B) - for (i <- matrixA_row(0).indices) { - for (j <- 0 until peCount) { - dut.io.matrixA_row.bits(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) - dut.io.matrixB_cols.bits(i)(j).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) - } + + for { + i <- matrixA_row(0).indices + j <- 0 until peCount + } { + dut.io.matrixA_row.bits(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) + dut.io.matrixB_cols.bits(i)(j).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) } } else { dut.io.matrixA_row.valid.poke(false.B) @@ -240,7 +55,6 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe } dut.io.blockResult.ready.poke(true.B) - val precision = 0.001f var invalidcnt = 0 @@ -248,22 +62,9 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val outBigInt = dut.io.blockResult.bits(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) val expected = expectedResults(0)(i) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - }) - - if (isInvalid) { - println("Error: ") - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 + checkResult(out, expected, 0, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right } } @@ -285,6 +86,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val matrixA = matInit[T](m, k) val matrixB = matInit[T](k, n) val expectedResults = mmul(matrixA, matrixB) + printmat(matrixA) printmat(matrixB) printmat(expectedResults) @@ -293,13 +95,14 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe println("matrixA and matrixB are ready") dut.io.matrixA.valid.poke(true.B) dut.io.matrixB.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) - dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) - } - } + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) } } else { dut.io.matrixA.valid.poke(false.B) @@ -311,35 +114,22 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe } dut.io.results.ready.poke(true.B) - val precision = 0.001f var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until n) { - val outBigInt = dut.io.results.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - }) - // printmat(Array(Array(out))) - // printmat(Array(Array(expected))) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } + for { + row <- 0 until m + col <- 0 until n + } { + val outBigInt = dut.io.results.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right } } + if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } @@ -358,19 +148,23 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val matrixA = matInit[T](m, k) val matrixB = matInit[T](k, n) val expectedResults = mmul(matrixA, matrixB) - // printmat(expectedResults) + + printmat(matrixA) + printmat(matrixB) + printmat(expectedResults) if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { println("matrixA and matrixB are ready") dut.io.matrixA.valid.poke(true.B) dut.io.matrixB.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) - dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) - } - } + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) } } else { dut.io.matrixA.valid.poke(false.B) @@ -383,34 +177,22 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe while (!dut.io.done.peekBoolean()) { if (dut.io.currentRow.valid.peekBoolean()) { val currentRowIndex = dut.io.currentRow.bits.index.peekInt() - println("currentRow index: " + currentRowIndex) + println(s"currentRow index: $currentRowIndex") + for (i <- 0 until n) { val outBigInt = dut.io.currentRow.bits.value(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) val expected = expectedResults(currentRowIndex.toInt)(i) - println("i: " + i) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - }) - if (isInvalid) { - println("Error: ") - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 + + checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right } } } dut.clock.step() } + if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } @@ -428,19 +210,23 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val matrixA = matInit[T](m, k) val matrixB = matInit[T](k, n) val expectedResults = mmul(matrixA, matrixB) + + printmat(matrixA) + printmat(matrixB) printmat(expectedResults) if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { println("matrixA and matrixB are ready") dut.io.matrixA.valid.poke(true.B) dut.io.matrixB.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) - dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) - } - } + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) } } else { dut.io.matrixA.valid.poke(false.B) @@ -448,806 +234,41 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe } dut.io.currentRow.ready.poke(true.B) - val precision = 0.001f var invalidcnt = 0 while (!dut.io.done.peekBoolean()) { if (dut.io.currentRow.valid.peekBoolean()) { val currentRowIndex = dut.io.currentRow.bits.index.peekInt() - println("currentRow index: " + currentRowIndex) + println(s"currentRow index: $currentRowIndex") + for (i <- 0 until n) { val outBigInt = dut.io.currentRow.bits.value(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) val expected = expectedResults(currentRowIndex.toInt)(i) - println("i: " + i) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: ") - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - dut.clock.step() - } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - private def testQKGen[T: Numeric: ClassTag]( - dut: QKGen - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val k = dut.k - val n = dut.n - val gemmType = dut.gemmType - - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - printmat(Query) - val Key = mmul(inputToken, weightK) - printmat(Key) - - if ( - dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) - } - } - } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) - } - - while (!(dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean())) { - dut.clock.step() - } - - dut.io.Key.ready.poke(true.B) - dut.io.Query.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until n) { - val outBigInt = dut.io.Query.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = Query(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - - for (row <- 0 until m) { - for (col <- 0 until n) { - val outBigInt = dut.io.Key.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = Key(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - private def testQKGenWithReg[T: Numeric: ClassTag]( - dut: QKGenWithReg - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val k = dut.k - val n = dut.n - val gemmType = dut.gemmType - - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - printmat(Query) - val Key = mmul(inputToken, weightK) - printmat(Key) - - if ( - dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + + checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right } } } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) - } - - while (!(dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean())) { dut.clock.step() } - dut.io.Key.ready.poke(true.B) - dut.io.Query.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until n) { - val outBigInt = dut.io.Query.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = Query(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - - for (row <- 0 until m) { - for (col <- 0 until n) { - val outBigInt = dut.io.Key.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = Key(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } - private def testQKMulTotal[T: Numeric: ClassTag]( - dut: QKMulTotal - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val n = dut.n - val gemmType = dut.gemmType - // println("m: " + m + " n: " + n) - - // val fixedMatrix = Array( - // Array(4, -1, 3, 1), - // Array(0, 5, -3, 3), - // Array(4, -2, 4, 0), - // Array(0, 3, -1, 3) - // ) - // val Query = fixedMatrix - // val Key = fixedMatrix - val Query = matInit[T](m, n) - val Key = matInit[T](m, n) - val expectedResults = mmul(Query, Key.transpose) - println("Query:") - printmat(Query) - println("Key:") - printmat(Key) - println("expectedResults:") - printmat(expectedResults) - - if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { - println(" Query and Key are ready") - dut.io.Query.valid.poke(true.B) - dut.io.Key.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) - dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) - } + "GEMMSingleQueue " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testGEMMSingleQueue[Int](dut) } - } else { - dut.io.Query.valid.poke(false.B) - dut.io.Key.valid.poke(false.B) - } - - while (!dut.io.scores.valid.peekBoolean()) { - dut.clock.step() - } - - dut.io.scores.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until m) { - val outBigInt = dut.io.scores.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") } - private def testQKMulTotalWithReg[T: Numeric: ClassTag]( - dut: QKMulTotalWithReg - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val n = dut.n - val gemmType = dut.gemmType - // println("m: " + m + " n: " + n) - - // val fixedMatrix = Array( - // Array(4, -1, 3, 1), - // Array(0, 5, -3, 3), - // Array(4, -2, 4, 0), - // Array(0, 3, -1, 3) - // ) - // val Query = fixedMatrix - // val Key = fixedMatrix - val Query = matInit[T](m, n) - val Key = matInit[T](m, n) - val expectedResults = mmul(Query, Key.transpose) - - println("Query:") - printmat(Query) - println("Key:") - printmat(Key) - println("expectedResults:") - printmat(expectedResults) - - if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { - println(" Query and Key are ready") - dut.io.Query.valid.poke(true.B) - dut.io.Key.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) - dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) - } - } - } else { - dut.io.Query.valid.poke(false.B) - dut.io.Key.valid.poke(false.B) - } - - while (!dut.io.scores.valid.peekBoolean()) { - dut.clock.step() - } - - dut.io.scores.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until m) { - val outBigInt = dut.io.scores.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - private def testAttnScores[T: Numeric: ClassTag]( - dut: AttnScores - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val k = dut.k - val n = dut.n - val gemmType = dut.gemmType - - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - printmat(Query) - val Key = mmul(inputToken, weightK) - printmat(Key.transpose) - val expectedResults = mmul(Query, Key.transpose) // Query * Key^T - printmat(expectedResults) - - if ( - dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) - } - } - } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) - } - - while (!dut.io.scores.valid.peekBoolean()) { - dut.clock.step() - } - - dut.io.scores.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until m) { - val outBigInt = dut.io.scores.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - private def testAttnScoresTotal[T: Numeric: ClassTag]( - dut: AttnScoresTotal - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val k = dut.k - val n = dut.n - val gemmType = dut.gemmType - - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - printmat(Query) - val Key = mmul(inputToken, weightK) - printmat(Key.transpose) - val expectedResults = mmul(Query, Key.transpose) // Query * Key^T - printmat(expectedResults) - - if ( - dut.io.inputToken.ready.peekBoolean() && dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until k) { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) - } - } - } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) - } - - while (!dut.io.scores.valid.peekBoolean()) { - dut.clock.step() - } - - dut.io.scores.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until m) { - val outBigInt = dut.io.scores.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - private def testOutValue[T: Numeric: ClassTag]( - dut: OutValue - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val n = dut.n - val gemmType = dut.gemmType - - val AttnWeights = matInit[T](m, m) - val Value = matInit[T](m, n) - val expectedResults = mmul(AttnWeights, Value) - - if (dut.io.AttnWeights.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { - println("AttnWeights and Value are ready") - dut.io.AttnWeights.valid.poke(true.B) - dut.io.Value.valid.poke(true.B) - for (row <- 0 until m) { - for (col <- 0 until n) { - for (i <- 0 until m) { - dut.io.AttnWeights.bits(row)(i).poke(toBinaryBigInt(AttnWeights(row)(i)).U) - dut.io.Value.bits(i)(col).poke(toBinaryBigInt(Value(i)(col)).U) - } - } - } - } else { - dut.io.AttnWeights.valid.poke(false.B) - dut.io.Value.valid.poke(false.B) - } - - while (!dut.io.AttnOut.valid.peekBoolean()) { - dut.clock.step() - } - - dut.io.AttnOut.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - for (row <- 0 until m) { - for (col <- 0 until n) { - val outBigInt = dut.io.AttnOut.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - - }) - if (isInvalid) { - println("Error: row: " + row + " col: " + col) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - private def testOutValueSingle[T: Numeric: ClassTag]( - dut: OutValueSingle - )( - implicit config: DataWidthConfig - ): Unit = { - val m = dut.m - val n = dut.n - val gemmType = dut.gemmType - - val AttnWeights = matInit[T](m, m) - val Value = matInit[T](m, n) - val expectedResults = mmul(AttnWeights, Value) - printmat(AttnWeights) - printmat(Value) - printmat(expectedResults) - - val precision = 0.001f - var invalidcnt = 0 - for (index <- 0 until m) { - println("index: " + index) - if (dut.io.currentAttnW.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { - println("currentAttnW index :" + index + " and Value are ready") - - dut.io.currentAttnW.valid.poke(true.B) - dut.io.Value.valid.poke(true.B) - for (i <- 0 until m) { - dut.io.currentAttnW.bits.value(i).poke(toBinaryBigInt(AttnWeights(index)(i)).U) - for (j <- 0 until n) { - dut.io.Value.bits(i)(j).poke(toBinaryBigInt(Value(i)(j)).U) - } - } - - } else { - dut.io.currentAttnW.valid.poke(false.B) - dut.io.Value.valid.poke(false.B) - } - while (!dut.io.currentAttnO.valid.peekBoolean()) { - dut.io.currentAttnO.ready.poke(false.B) - dut.clock.step() - } - - dut.io.currentAttnO.ready.poke(true.B) - - val currentRowIndex = dut.io.currentAttnO.bits.index.peekInt() - // println("currentRow index:" + currentRowIndex + " expected: " + index) - for (i <- 0 until n) { - val outBigInt = dut.io.currentAttnO.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(currentRowIndex.toInt)(i) - - val isInvalid = (implicitly[ClassTag[T]].runtimeClass match { - case c if c == classOf[Float] => - math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision - case c if c == classOf[Double] => - math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision - - case c if c == classOf[Int] => - math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision - case _ => - throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") - }) - if (isInvalid) { - println("Error: " + i) - printmat(Array(Array(out))) - printmat(Array(Array(expected))) - invalidcnt += 1 - } - } - dut.clock.step() - - } - - // while (!dut.io.done.peekBoolean()) { - // dut.clock.step() - // } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") - } - - // ===--::--=== - // below tests ERROR - // ===--::--=== - - - // "AttnScoresTotal " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new AttnScoresTotal(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testAttnScoresTotal[Int](dut) - // } - // } - - // "AttnScores " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new AttnScores(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testAttnScores[Int](dut) - // } - // } - - // "AttnScores " should "compute fp32 matrix multiplication" in { - // implicit val config: DataWidthConfig = Fp32Config - // test(new AttnScores(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testAttnScores[Float](dut) - // } - // } - - // ===--::--=== - // below tests PASS - // ===--::--=== - - // "OutValueSingle " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new OutValueSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testOutValueSingle[Int](dut) - // } - // } - - // "OutValue " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new OutValue(m = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testOutValue[Int](dut) - // } - // } - - // "QKMulTotal " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new QKMulTotal(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testQKMulTotal[Int](dut) - // } - // } - - // "QKMulTotalWithReg " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new QKMulTotalWithReg(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testQKMulTotalWithReg[Int](dut) - // } - // } - - // "QKGen " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new QKGen(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testQKGen[Int](dut) - // } - // } - - // "QKGenWithReg " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new QKGenWithReg(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testQKGenWithReg[Int](dut) - // } - // } - - // "GEMMSingleQueue " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMSingleQueue[Int](dut) - // } - // } - // "GEMMSingleQueue " should "compute fp32 matrix multiplication" in { // implicit val config: DataWidthConfig = Fp32Config // test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fp32)) @@ -1258,7 +279,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe // "GEMMFMATotal " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new GEMMFMATotal(m = 4, k = 4, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // test(new GEMMFMATotal(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => // testGEMMFMATotal[Int](dut) // } diff --git a/src/test/scala/kernel/alu/OutValueTest.scala b/src/test/scala/kernel/alu/OutValueTest.scala new file mode 100644 index 0000000..a48a381 --- /dev/null +++ b/src/test/scala/kernel/alu/OutValueTest.scala @@ -0,0 +1,162 @@ +package kernel.alu + +import chisel3._ +import chiseltest._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.ParallelTestExecution +import scala.reflect.ClassTag +import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} +import ujson.Arr +import Utils._ + +class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution { + + private def testOutValue[T: Numeric: ClassTag]( + dut: OutValue + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val Scores = matInit[T](m, m) + val Value = matInit[T](m, n) + val expectedResults = mmul(Scores, Value) + + printmat(Scores) + printmat(Value) + printmat(expectedResults) + + if (dut.io.Scores.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { + println("Scores and Value are ready") + dut.io.Scores.valid.poke(true.B) + dut.io.Value.valid.poke(true.B) + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until m + } { + dut.io.Scores.bits(row)(i).poke(toBinaryBigInt(Scores(row)(i)).U) + dut.io.Value.bits(i)(col).poke(toBinaryBigInt(Value(i)(col)).U) + } + } else { + dut.io.Scores.valid.poke(false.B) + dut.io.Value.valid.poke(false.B) + } + + while (!dut.io.AttnOut.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.AttnOut.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + + for { + row <- 0 until m + col <- 0 until n + } { + val outBigInt = dut.io.AttnOut.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testOutValueSingle[T: Numeric: ClassTag]( + dut: OutValueSingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val AttnWeights = matInit[T](m, m) + val Value = matInit[T](m, n) + val expectedResults = mmul(AttnWeights, Value) + + printmat(AttnWeights) + printmat(Value) + printmat(expectedResults) + + val precision = 0.001f + var invalidcnt = 0 + + if (dut.io.Value.ready.peekBoolean()) { + println("Value is ready") + dut.io.Value.valid.poke(true.B) + + for { + i <- 0 until m + j <- 0 until n + } { + dut.io.Value.bits(i)(j).poke(toBinaryBigInt(Value(i)(j)).U) + } + } else { + dut.io.Value.valid.poke(false.B) + } + + for (index <- 0 until m) { + if (dut.io.currentScores.ready.peekBoolean()) { + println(s"currentScores index: $index is ready") + dut.io.currentScores.valid.poke(true.B) + + for (i <- 0 until m) { + dut.io.currentScores.bits.value(i).poke(toBinaryBigInt(AttnWeights(index)(i)).U) + } + } else { + dut.io.currentScores.valid.poke(false.B) + } + + dut.io.currentAttnOut.ready.poke(false.B) + while (!dut.io.currentAttnOut.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.currentAttnOut.ready.poke(true.B) + val currentRowIndex = dut.io.currentAttnOut.bits.index.peekInt() + + for (i <- 0 until n) { + val outBigInt = dut.io.currentAttnOut.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(currentRowIndex.toInt)(i) + + checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + dut.clock.step() + } + + dut.io.done.expect(true.B) + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + "OutValueSingle " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new OutValueSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testOutValueSingle[Int](dut) + } + } + + // "OutValue " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new OutValue(m = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testOutValue[Int](dut) + // } + // } +} diff --git a/src/test/scala/kernel/alu/utils.scala b/src/test/scala/kernel/alu/utils.scala new file mode 100644 index 0000000..c61af5f --- /dev/null +++ b/src/test/scala/kernel/alu/utils.scala @@ -0,0 +1,214 @@ +package kernel.alu +import chisel3._ +import chiseltest._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.ParallelTestExecution +import scala.reflect.ClassTag +import kernel.alu.{DataWidthConfig, Fp32Config, Fp64Config, FxpConfig, GEMMDataType} +import ujson.Arr + +object Utils { + def mmul[T: Numeric: ClassTag](a: Array[Array[T]], b: Array[Array[T]]): Array[Array[T]] = { + val rows = a.length + val cols = b(0).length + val n = b.length + val num = implicitly[Numeric[T]] + + Array.tabulate(rows, cols) { (i, j) => + var sum = num.zero + for (k <- 0 until n) { + sum = num.plus(sum, num.times(a(i)(k), b(k)(j))) + } + sum + } + } + + def matInit[T: Numeric: ClassTag]( + rows: Int, + cols: Int + )( + implicit config: DataWidthConfig + ): Array[Array[T]] = { + val r = new scala.util.Random(42) + val ct = implicitly[ClassTag[T]] + val numeric = implicitly[Numeric[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + // 定点数使用 -8 到 7 的整数 + Array.fill(rows, cols)( + numeric.fromInt( + // r.nextInt(math.pow(2, config.inputWidth).toInt) - math.pow(2, config.inputWidth - 1).toInt + r.nextInt(4) - 2 + ) + ) + case c if c == classOf[Float] => + // 32位浮点数使用 -1 到 1 的随机浮点数 + // Float 类型 + Array.fill(rows, cols)((r.nextFloat() * 2 - 1).asInstanceOf[T]) + case c if c == classOf[Double] => + // 64位浮点数使用 -1 到 1 的随机浮点数 + Array.fill(rows, cols)((r.nextDouble() * 2 - 1).asInstanceOf[T]) + case _ => + throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") + } + } + + def toSignedBigInt(value: BigInt, width: Int): BigInt = { + val signBit = (value >> (width - 1)) & 1 + + if (signBit == 1) { + val maxValue = BigInt(1) << width + value - maxValue + } else { + value + } + } + + def printmat[T: Numeric: ClassTag](m: Array[Array[T]]): Unit = { + val numeric = implicitly[Numeric[T]] + val ct = implicitly[ClassTag[T]] + + m.foreach { r => + r.foreach { v => + ct.runtimeClass match { + case c if c == classOf[Float] => + print(f"${v.asInstanceOf[Float]}%.4f\t") + case c if c == classOf[Double] => + print(f"${v.asInstanceOf[Double]}%.4f\t") + case c if c == classOf[Int] => + print(f"${v.asInstanceOf[Int]}%d\t") + case _ => + throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") + } + } + println(";") + } + println() + } + + def printmat[T: Numeric: ClassTag](m: Array[T], x: Int, y: Int)(implicit config: DataWidthConfig): Unit = { + val numeric = implicitly[Numeric[T]] + val ct = implicitly[ClassTag[T]] + + for (i <- 0 until x) { + for (j <- 0 until y) { + ct.runtimeClass match { + case c if c == classOf[Float] => + print(f"${m(i * y + j).asInstanceOf[Float]}%.4f\t") + case c if c == classOf[Double] => + print(f"${m(i * y + j).asInstanceOf[Double]}%.4f\t") + case c if c == classOf[Int] => + print(f"${m(i * y + j).asInstanceOf[Int]}%d\t") + case c if c == classOf[BigInt] => + print(f"${toSignedBigInt(m(i * y + j).asInstanceOf[BigInt], config.inputWidth)}%d\t") + case _ => + throw new IllegalArgumentException(s"不支持的数据类型: ${ct.runtimeClass}") + } + } + println(";") + } + println() + } + + // convert T to binary bigInt + def toBinaryBigInt[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): BigInt = { + val ct = implicitly[ClassTag[T]] + val num = implicitly[Numeric[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + val intValue = v.asInstanceOf[Int] + // 使用 inputWidth 位来表示所有整数,保持符号位 + val mask = (1L << config.inputWidth) - 1 + BigInt(intValue) & mask + case c if c == classOf[Float] => + BigInt(java.lang.Float.floatToRawIntBits(v.asInstanceOf[Float]).toBinaryString, 2) + case c if c == classOf[Double] => + BigInt(java.lang.Double.doubleToRawLongBits(v.asInstanceOf[Double]).toBinaryString, 2) + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") + } + } + + // convrt T to binary string + def toBinaryString[T: Numeric: ClassTag](v: T)(implicit config: DataWidthConfig): String = { + val ct = implicitly[ClassTag[T]] + val num = implicitly[Numeric[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + val intBValue = v.asInstanceOf[Int].toBinaryString + if (intBValue.length < config.inputWidth) { + intBValue.reverse.padTo(config.inputWidth, '0').reverse + } else { + intBValue.takeRight(config.inputWidth) + } + case c if c == classOf[Float] => + java.lang.Float + .floatToRawIntBits(v.asInstanceOf[Float]) + .toBinaryString + .reverse + .padTo(config.inputWidth, '0') + .reverse + case c if c == classOf[Double] => + java.lang.Double + .doubleToRawLongBits(v.asInstanceOf[Double]) + .toBinaryString + .reverse + .padTo(config.inputWidth, '0') + .reverse + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") + } + } + + // convert binary bigInt to T + def fromBinaryBigInt[T: Numeric: ClassTag](bigInt: BigInt)(implicit config: DataWidthConfig): T = { + val ct = implicitly[ClassTag[T]] + + ct.runtimeClass match { + case c if c == classOf[Int] => + val intValue = bigInt.toInt + // 处理符号位 + val signExtendedValue = if ((intValue & (1 << (config.inputWidth - 1))) != 0) { + intValue | ~((1 << config.inputWidth) - 1) + } else { + intValue + } + signExtendedValue.asInstanceOf[T] + case c if c == classOf[Float] => + java.lang.Float.intBitsToFloat(bigInt.toInt).asInstanceOf[T] + case c if c == classOf[Double] => + java.lang.Double.longBitsToDouble(bigInt.toLong).asInstanceOf[T] + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${ct.runtimeClass}") + } + } + + def checkResult[T: Numeric: ClassTag]( + out: T, + expected: T, + row: Int, + col: Int, + precision: Float + ): Option[Unit] = { + val isInvalid = implicitly[ClassTag[T]].runtimeClass match { + case c if c == classOf[Float] => + math.abs(out.asInstanceOf[Float] - expected.asInstanceOf[Float]) > precision + case c if c == classOf[Double] => + math.abs(out.asInstanceOf[Double] - expected.asInstanceOf[Double]) > precision + case c if c == classOf[Int] => + math.abs(out.asInstanceOf[Int] - expected.asInstanceOf[Int]) > precision + case _ => + throw new IllegalArgumentException(s"Unsupported type: ${implicitly[ClassTag[T]].runtimeClass}") + } + + if (isInvalid) { + println(s"Error: row: $row col: $col") + printmat(Array(Array(out))) + printmat(Array(Array(expected))) + Some(()) + } else None + } +} From 0a0ab2b2c698c7ff1b40df3a57343bd785c347ea Mon Sep 17 00:00:00 2001 From: pyfirstcsh Date: Mon, 13 Jan 2025 23:27:47 +0800 Subject: [PATCH 06/10] Single FxpError &&SingleQueue Error --- src/main/scala/kernel/alu/AttnScores.scala | 314 +++++++++++++++++- src/main/scala/kernel/alu/GemmFMA.scala | 197 ++++++----- src/main/scala/kernel/alu/OutValue.scala | 34 +- .../scala/kernel/alu/AttnScoresTest.scala | 278 ++++++++++++++-- src/test/scala/kernel/alu/GemmFMATest.scala | 156 +++++++-- src/test/scala/kernel/alu/OutValueTest.scala | 24 +- 6 files changed, 831 insertions(+), 172 deletions(-) diff --git a/src/main/scala/kernel/alu/AttnScores.scala b/src/main/scala/kernel/alu/AttnScores.scala index 824851d..b5dd89a 100644 --- a/src/main/scala/kernel/alu/AttnScores.scala +++ b/src/main/scala/kernel/alu/AttnScores.scala @@ -269,8 +269,6 @@ class QKMul( io.scores.valid := false.B io.scores.bits := DontCare - val doneReg = RegInit(false.B) - val QK_TMul = Module(new GEMMFMATotal(m, n, m, peCount, gemmType)) QK_TMul.io.matrixA.valid := io.Query.valid @@ -292,11 +290,13 @@ class QKMul( } } is(state.mul) { + QK_TMul.io.results.ready := true.B when(QK_TMul.io.results.valid) { stateReg := state.done } } is(state.done) { + QK_TMul.io.results.ready := false.B readyReg := true.B io.scores.valid := true.B io.scores.bits := QK_TMul.io.results.bits @@ -304,6 +304,83 @@ class QKMul( } } } +// QKMulSingle: use GEMMFMASingle to get scores by row +// input: Query: m * n +// input: Key: m * n +// output: scores: 1 * m +// output: done: Bool +class QKMulSingle( + val m: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val Query = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val Key = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) + val curRowScores = Decoupled(new curRowIndex(m, m)) + val done = Output(Bool()) + }) + + val dataValid = io.Query.valid && io.Key.valid + + val readyReg = RegInit(true.B) + io.Query.ready := readyReg + io.Key.ready := readyReg + io.curRowScores.valid := false.B + io.curRowScores.bits := DontCare + io.done := false.B + + val curRowIndexReg = Reg(new curRowIndex(m, m)) + + val QK_TMul = Module(new GEMMFMASingle(m, n, m, peCount, gemmType)) + // val QK_TMul = Module(new GEMMSingleQueue(m, n, m, peCount, gemmType)) + + QK_TMul.io.matrixA.valid := io.Query.valid + QK_TMul.io.matrixA.bits := io.Query.bits + QK_TMul.io.matrixB.valid := io.Key.valid + QK_TMul.io.matrixB.bits := VecInit(io.Key.bits.transpose.map(VecInit(_))) + QK_TMul.io.curRow.ready := false.B + + object state extends ChiselEnum { + val idle, mul, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.mul + } + } + is(state.mul) { + QK_TMul.io.curRow.ready := true.B + when(QK_TMul.io.curRow.valid) { + curRowIndexReg := QK_TMul.io.curRow.bits + stateReg := state.update + } + } + is(state.update) { + QK_TMul.io.curRow.ready := false.B + io.curRowScores.valid := true.B + io.curRowScores.bits := curRowIndexReg + when(QK_TMul.io.done) { + stateReg := state.done + }.otherwise { + stateReg := state.mul + } + } + is(state.done) { + readyReg := true.B + io.done := true.B + stateReg := state.idle + } + } +} // AttnScores: use QKGen to get Q and K, then use QKMul to get scores // input: inputToken: m * k @@ -398,3 +475,236 @@ class AttnScores( } } } + +// AttnScoresSingle: use QKGen to get Q and K, get scores by row +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// output: curRowScores: 1 * m +// output: done: Bool +class AttnScoresSingle( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val curRowScores = Decoupled(new curRowIndex(m, m)) + val done = Output(Bool()) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + io.curRowScores.valid := false.B + io.curRowScores.bits := DontCare + io.done := false.B + + // val QKGen = Module(new QKGenWithReg(m, k, n, peCount, gemmType)) + val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) + + QKGen.io.inputToken.valid := io.inputToken.valid + QKGen.io.inputToken.bits := io.inputToken.bits + QKGen.io.weightQ.valid := io.weightQ.valid + QKGen.io.weightQ.bits := io.weightQ.bits + QKGen.io.weightK.valid := io.weightK.valid + QKGen.io.weightK.bits := io.weightK.bits + + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + + val curRowIndexReg = Reg(new curRowIndex(m, m)) + val QKMul = Module(new GEMMFMASingle(m, n, m, peCount, gemmType)) + + QKMul.io.matrixA.valid := QKGen.io.Query.valid + QKMul.io.matrixA.bits := QKGen.io.Query.bits + QKMul.io.matrixB.valid := QKGen.io.Key.valid + QKMul.io.matrixB.bits := VecInit(QKGen.io.Key.bits.transpose.map(VecInit(_))) + QKMul.io.curRow.ready := false.B + + object state extends ChiselEnum { + val idle, gen, mul, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + QKGen.io.Query.ready := true.B + QKGen.io.Key.ready := true.B + when(QKGen.io.Query.valid && QKGen.io.Key.valid) { + stateReg := state.mul + } + } + is(state.mul) { + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + QKMul.io.curRow.ready := true.B + // printf(p"QKGen.io.Query.bits: ${QKGen.io.Query.bits}\n") + // printf(p"QKGen.io.Key.bits: ${QKGen.io.Key.bits}\n") + // printf(p"QKMul.io.matrixA.bits: ${QKMul.io.matrixA.bits}\n") + // printf(p"QKMul.io.matrixB.bits: ${QKMul.io.matrixB.bits}\n") + when(QKMul.io.curRow.valid) { + curRowIndexReg := QKMul.io.curRow.bits + printf(p"QKMul.io.curRow.bits: ${QKMul.io.curRow.bits}\n") + printf(p"curRowIndexReg: ${curRowIndexReg}\n") + stateReg := state.update + } + } + is(state.update) { + QKMul.io.curRow.ready := false.B + io.curRowScores.valid := true.B + io.curRowScores.bits := curRowIndexReg + printf(p"Update curRowIndexReg: ${curRowIndexReg}\n") + printf(p"Update io.curRowScores.bits: ${io.curRowScores.bits}\n") + printf(p"Update io.curRowScores.valid: ${io.curRowScores.valid}\n") + when(QKMul.io.done) { + stateReg := state.done + }.otherwise { + stateReg := state.mul + } + } + is(state.done) { + readyReg := true.B + io.done := true.B + stateReg := state.idle + } + } +} + +// AttnScoresSingleQueue: use QKGen to get Q and K, get scores by row +// input: inputToken: m * k +// input: weightQ: k * n +// input: weightK: k * n +// input: flush: Bool +// output: curRowScores: 1 * m +// output: done: Bool +class AttnScoresSingleQueue( + val m: Int, + val k: Int, + val n: Int, + val peCount: Int = 16, + val gemmType: GEMMDataType.Type, + val bufferSize: Int = 32 +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val inputToken = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val weightQ = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val weightK = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val flush = Input(Bool()) + val curRowScores = Decoupled(new curRowIndex(m, m)) + val done = Output(Bool()) + }) + + val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid + + val readyReg = RegInit(true.B) + io.inputToken.ready := readyReg + io.weightQ.ready := readyReg + io.weightK.ready := readyReg + io.curRowScores.valid := false.B + io.curRowScores.bits := DontCare + io.done := false.B + + // val QKGen = Module(new QKGenWithReg(m, k, n, peCount, gemmType)) + val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) + + QKGen.io.inputToken.valid := io.inputToken.valid + QKGen.io.inputToken.bits := io.inputToken.bits + QKGen.io.weightQ.valid := io.weightQ.valid + QKGen.io.weightQ.bits := io.weightQ.bits + QKGen.io.weightK.valid := io.weightK.valid + QKGen.io.weightK.bits := io.weightK.bits + + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + + val curRowIndexReg = Reg(new curRowIndex(m, m)) + + val QKMul = Module(new GEMMFMASingle(m, n, m, peCount, gemmType)) + + QKMul.io.matrixA.valid := QKGen.io.Query.valid + QKMul.io.matrixA.bits := QKGen.io.Query.bits + QKMul.io.matrixB.valid := QKGen.io.Key.valid + QKMul.io.matrixB.bits := VecInit(QKGen.io.Key.bits.transpose.map(VecInit(_))) + QKMul.io.curRow.ready := false.B + + val curBuffer = Module( + new Queue( + new curRowIndex(m, n), + entries = bufferSize, + pipe = true, + flow = false, + useSyncReadMem = false, + hasFlush = true + ) + ) + curBuffer.io.flush.get := io.flush + curBuffer.io.enq <> QKMul.io.curRow + io.curRowScores <> curBuffer.io.deq + + object state extends ChiselEnum { + val idle, gen, mul, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(dataValid) { + readyReg := false.B + stateReg := state.gen + } + } + is(state.gen) { + QKGen.io.Query.ready := true.B + QKGen.io.Key.ready := true.B + when(QKGen.io.Query.valid && QKGen.io.Key.valid) { + stateReg := state.mul + } + } + is(state.mul) { + QKGen.io.Query.ready := false.B + QKGen.io.Key.ready := false.B + QKMul.io.curRow.ready := true.B + // printf(p"QKGen.io.Query.bits: ${QKGen.io.Query.bits}\n") + // printf(p"QKGen.io.Key.bits: ${QKGen.io.Key.bits}\n") + when(QKMul.io.curRow.valid) { + // curRowIndexReg := QKMul.io.curRow.bits + stateReg := state.update + } + } + is(state.update) { + QKMul.io.curRow.ready := false.B + io.curRowScores.valid := true.B + // io.curRowScores.bits := curRowIndexReg + when(QKMul.io.done) { + stateReg := state.done + }.otherwise { + stateReg := state.mul + } + } + is(state.done) { + readyReg := true.B + io.done := true.B + stateReg := state.idle + } + } +} diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala index 55ac1f3..4335258 100644 --- a/src/main/scala/kernel/alu/GemmFMA.scala +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -7,7 +7,7 @@ import kernel.alu.DataWidthConfig import kernel.utils.DebugLog import kernel.deprecated.PE -class currentRowIndex( +class curRowIndex( val m: Int, val n: Int )( @@ -77,72 +77,87 @@ class MultiFMA( optIndex := optIndex + 1.U } - // TODO: FSM :reset logic is not correct, need to be fixed - - // pes.foreach { pe => - // pe.in_h := 0.U - // pe.in_v := 0.U - // pe.reset := DontCare - // } - - // io.blockResult.valid := validReg - // io.blockResult.bits := DontCare - - // object state extends ChiselEnum { - // val idle, reset, compute, update, done = Value - // } - // val stateReg = RegInit(state.idle) - - // switch(stateReg) { - // is(state.idle) { - // when(dataValid) { - // readyReg := false.B - // stateReg := state.compute - // } - // } - // is(state.compute) { - // when(io.reset) { - // stateReg := state.reset - // } - // for (i <- 0 until peCount) { - // pes(i).reset := io.reset - // pes(i).in_h := io.matrixA_row.bits(optIndex) - // pes(i).in_v := io.matrixB_cols.bits(optIndex)(i) - // io.blockResult.bits(i) := pes(i).out - // } - - // // printf(p"optIndex: ${optIndex}\n") - // // printf(p"io.matrixA_row.bits(${optIndex}): ${io.matrixA_row.bits(optIndex)}\n") - // // for (i <- 0 until peCount) { - // // printf(p"pe: $i\n") - // // printf(p"io.matrixB_cols.bits(${optIndex})($i): ${io.matrixB_cols.bits(optIndex)(i)}\n") - // // printf(p"io.blockResult.bits(${i}): ${io.blockResult.bits(i)}\n") - // // } - // stateReg := state.update - // } - // is(state.reset) { - // optIndex := 0.U - // validReg := false.B - // stateReg := state.idle - // } - // is(state.update) { - // validReg := false.B - // when(optIndex === (k - 1).U) { - // stateReg := state.done - // }.otherwise { - // optIndex := optIndex + 1.U - // stateReg := state.compute - // } - // } - // is(state.done) { - // optIndex := 0.U - // readyReg := true.B - // validReg := true.B - // stateReg := state.idle - // } - // } } +class MultiFMA_v2( + val k: Int, + val peCount: Int, + val gemmType: GEMMDataType.Type +)( + implicit config: DataWidthConfig) + extends Module + with DebugLog { + val io = IO(new Bundle { + val matrixA_row = Flipped(Decoupled(Vec(k, UInt(config.inputWidth.W)))) + val matrixB_cols = Flipped(Decoupled(Vec(k, Vec(peCount, UInt(config.inputWidth.W))))) + val blockResult = Decoupled(Vec(peCount, UInt(config.outputWidth.W))) + val reset = Input(Bool()) + }) + + val dataValid = io.matrixA_row.valid && io.matrixB_cols.valid + + val readyReg = RegInit(true.B) + io.matrixA_row.ready := readyReg + io.matrixB_cols.ready := readyReg + io.blockResult.valid := false.B + io.blockResult.bits := DontCare + + val pes = Seq.fill(peCount)(gemmType match { + case GEMMDataType.Fxp => Module(new PEFxp()).io + case GEMMDataType.Fp32 => Module(new PEFp()).io + case GEMMDataType.Fp64 => Module(new PEFp()).io + case _ => throw new IllegalArgumentException("Unsupported GEMM type") + }) + + val optIndex = Counter(k) + + pes.foreach { pe => + pe.in_h := 0.U + pe.in_v := 0.U + pe.reset := DontCare + } + + object state extends ChiselEnum { + val idle, compute, update, done = Value + } + val stateReg = RegInit(state.idle) + + switch(stateReg) { + is(state.idle) { + when(io.reset) { + optIndex.reset() + for (i <- 0 until peCount) { + pes(i).reset := true.B + } + } + when(dataValid) { + readyReg := false.B + stateReg := state.compute + } + } + is(state.compute) { + for (i <- 0 until peCount) { + pes(i).reset := false.B + pes(i).in_h := io.matrixA_row.bits(optIndex.value) + pes(i).in_v := io.matrixB_cols.bits(optIndex.value)(i) + io.blockResult.bits(i) := pes(i).out + } + stateReg := state.update + } + is(state.update) { + when(optIndex.inc()) { + stateReg := state.done + }.otherwise { + stateReg := state.compute + } + } + is(state.done) { + readyReg := true.B + io.blockResult.valid := true.B + stateReg := state.idle + } + } +} // input: matrixA: m * k // input: matrixB: k * n // output: matrixC: m * n @@ -171,7 +186,7 @@ class GEMMFMATotal( io.results.valid := false.B io.results.bits := DontCare - val multiFMA = Module(new MultiFMA(k, peCount, gemmType)) + val multiFMA = Module(new MultiFMA_v2(k, peCount, gemmType)) val rowIndex = Counter(m) val colIndex = Counter(n / peCount) @@ -182,7 +197,7 @@ class GEMMFMATotal( multiFMA.io.matrixB_cols.valid := io.matrixB.valid multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(k) { j => VecInit(Seq.tabulate(peCount) { i => - io.matrixB.bits(j)((colIndex.value * peCount.U + i.U) % n.U) + io.matrixB.bits(j)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) }) }) //k * peCount size block of matrixB @@ -209,7 +224,8 @@ class GEMMFMATotal( multiFMA.io.blockResult.ready := true.B when(multiFMA.io.blockResult.valid) { for (i <- 0 until peCount) { - resultsReg(rowIndex.value)((colIndex.value * peCount.U + i.U) % n.U) := multiFMA.io.blockResult.bits(i) + resultsReg(rowIndex.value)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) := multiFMA.io.blockResult + .bits(i) } stateReg := state.update } @@ -241,7 +257,7 @@ class GEMMFMATotal( //input: matrixA: m * k //input: matrixB: k * n -//output: currentRowIndex: one row of matrixC: 1 * n and current row index +//output: curRowIndex: one row of matrixC: 1 * n and cur row index //output: done: total matrixC finish flag class GEMMFMASingle( val m: Int, @@ -257,7 +273,7 @@ class GEMMFMASingle( val io = IO(new Bundle { val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) - val currentRow = Decoupled(new currentRowIndex(m, n)) + val curRow = Decoupled(new curRowIndex(m, n)) val done = Output(Bool()) }) @@ -265,8 +281,8 @@ class GEMMFMASingle( val readyReg = RegInit(true.B) io.matrixA.ready := readyReg io.matrixB.ready := readyReg - io.currentRow.valid := false.B - io.currentRow.bits := DontCare + io.curRow.valid := false.B + io.curRow.bits := DontCare io.done := false.B val multiFMA = Module(new MultiFMA(k, peCount, gemmType)) @@ -280,14 +296,14 @@ class GEMMFMASingle( multiFMA.io.matrixB_cols.valid := io.matrixB.valid multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(k) { j => VecInit(Seq.tabulate(peCount) { i => - io.matrixB.bits(j)((colIndex.value * peCount.U + i.U) % n.U) + io.matrixB.bits(j)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) }) }) //k * peCount size block of matrixB - multiFMA.io.reset := false.B - multiFMA.io.blockResult.ready := true.B + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B - val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) + val curRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) object state extends ChiselEnum { val idle, compute, update, done = Value @@ -304,9 +320,10 @@ class GEMMFMASingle( is(state.compute) { multiFMA.io.reset := false.B + multiFMA.io.blockResult.ready := true.B when(multiFMA.io.blockResult.valid) { for (i <- 0 until peCount) { - currentRowReg((colIndex.value * peCount.U + i.U) % n.U) := multiFMA.io.blockResult.bits(i) + curRowReg((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) := multiFMA.io.blockResult.bits(i) } stateReg := state.update } @@ -314,11 +331,11 @@ class GEMMFMASingle( is(state.update) { multiFMA.io.reset := true.B - io.currentRow.valid := false.B + multiFMA.io.blockResult.ready := false.B when(colIndex.inc()) { - io.currentRow.valid := true.B - io.currentRow.bits.index := rowIndex.value - io.currentRow.bits.value := currentRowReg + io.curRow.valid := true.B + io.curRow.bits.index := rowIndex.value + io.curRow.bits.value := curRowReg when(rowIndex.inc()) { stateReg := state.done }.otherwise { @@ -330,6 +347,8 @@ class GEMMFMASingle( } is(state.done) { + multiFMA.io.reset := true.B + multiFMA.io.blockResult.ready := false.B io.done := true.B readyReg := true.B stateReg := state.idle @@ -349,16 +368,16 @@ class GEMMSingleQueue( extends Module with DebugLog { val io = IO(new Bundle { - val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) // 矩阵A - val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) // 矩阵B + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) val flush = Input(Bool()) - val currentRow = Decoupled(new currentRowIndex(m, n)) + val curRow = Decoupled(new curRowIndex(m, n)) val done = Output(Bool()) }) - val currentBuffer = Module( + val curBuffer = Module( new Queue( - new currentRowIndex(m, n), + new curRowIndex(m, n), entries = bufferSize, pipe = true, flow = false, @@ -369,9 +388,9 @@ class GEMMSingleQueue( val gemm = Module(new GEMMFMASingle(m, k, n, peCount, gemmType)) gemm.io.matrixA <> io.matrixA gemm.io.matrixB <> io.matrixB - currentBuffer.io.flush.get := io.flush - currentBuffer.io.enq <> gemm.io.currentRow - io.currentRow <> currentBuffer.io.deq + curBuffer.io.flush.get := io.flush + curBuffer.io.enq <> gemm.io.curRow + io.curRow <> curBuffer.io.deq io.done := gemm.io.done } diff --git a/src/main/scala/kernel/alu/OutValue.scala b/src/main/scala/kernel/alu/OutValue.scala index 36547ab..df8e86c 100644 --- a/src/main/scala/kernel/alu/OutValue.scala +++ b/src/main/scala/kernel/alu/OutValue.scala @@ -85,18 +85,18 @@ class OutValueSingle( extends Module with DebugLog { val io = IO(new Bundle { - val currentScores = Flipped(Decoupled(new currentRowIndex(m, m))) + val curScores = Flipped(Decoupled(new curRowIndex(m, m))) val Value = Flipped(Decoupled(Vec(m, Vec(n, UInt(config.inputWidth.W))))) - val currentAttnOut = Decoupled(new currentRowIndex(m, n)) + val curAttnOut = Decoupled(new curRowIndex(m, n)) val done = Output(Bool()) }) - val dataValid = io.currentScores.valid && io.Value.valid + val dataValid = io.curScores.valid && io.Value.valid - io.currentScores.ready := true.B + io.curScores.ready := true.B io.Value.ready := true.B - io.currentAttnOut.valid := false.B - io.currentAttnOut.bits := DontCare + io.curAttnOut.valid := false.B + io.curAttnOut.bits := DontCare io.done := false.B val ValueReg = Reg(Vec(m, Vec(n, UInt(config.inputWidth.W)))) @@ -107,8 +107,8 @@ class OutValueSingle( val rowIndex = Counter(m) val colIndex = Counter(n / peCount) - multiFMA.io.matrixA_row.valid := io.currentScores.valid - multiFMA.io.matrixA_row.bits := io.currentScores.bits.value + multiFMA.io.matrixA_row.valid := io.curScores.valid + multiFMA.io.matrixA_row.bits := io.curScores.bits.value multiFMA.io.matrixB_cols.valid := io.Value.valid multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(m) { j => @@ -120,7 +120,7 @@ class OutValueSingle( multiFMA.io.reset := true.B multiFMA.io.blockResult.ready := false.B - val currentRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) + val curRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) object state extends ChiselEnum { val idle, compute, update, load, done = Value @@ -136,12 +136,12 @@ class OutValueSingle( } } is(state.compute) { - io.currentScores.ready := false.B + io.curScores.ready := false.B multiFMA.io.reset := false.B multiFMA.io.blockResult.ready := true.B when(multiFMA.io.blockResult.valid) { for (i <- 0 until peCount) { - currentRowReg(colIndex.value * peCount.U + i.U) := multiFMA.io.blockResult.bits(i) + curRowReg(colIndex.value * peCount.U + i.U) := multiFMA.io.blockResult.bits(i) } stateReg := state.update } @@ -149,11 +149,11 @@ class OutValueSingle( is(state.update) { multiFMA.io.reset := true.B multiFMA.io.blockResult.ready := false.B - io.currentAttnOut.valid := false.B + io.curAttnOut.valid := false.B when(colIndex.inc()) { - io.currentAttnOut.valid := true.B - io.currentAttnOut.bits.index := rowIndex.value - io.currentAttnOut.bits.value := currentRowReg + io.curAttnOut.valid := true.B + io.curAttnOut.bits.index := rowIndex.value + io.curAttnOut.bits.value := curRowReg when(rowIndex.inc()) { stateReg := state.done }.otherwise { @@ -164,13 +164,13 @@ class OutValueSingle( } } is(state.load) { - io.currentScores.ready := true.B + io.curScores.ready := true.B stateReg := state.compute } is(state.done) { io.done := true.B io.Value.ready := true.B - io.currentScores.ready := true.B + io.curScores.ready := true.B stateReg := state.idle } } diff --git a/src/test/scala/kernel/alu/AttnScoresTest.scala b/src/test/scala/kernel/alu/AttnScoresTest.scala index 8f0556a..51c965c 100644 --- a/src/test/scala/kernel/alu/AttnScoresTest.scala +++ b/src/test/scala/kernel/alu/AttnScoresTest.scala @@ -94,7 +94,7 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val Query = matInit[T](m, n) val Key = matInit[T](m, n) val expectedResults = mmul(Query, Key.transpose) - + println("Query:") printmat(Query) println("Key:") @@ -106,7 +106,7 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { println("Query and Key are ready") dut.io.Query.valid.poke(true.B) dut.io.Key.valid.poke(true.B) - + for { row <- 0 until m col <- 0 until n @@ -126,7 +126,7 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { dut.io.scores.ready.poke(true.B) val precision = 0.001f var invalidcnt = 0 - + for { row <- 0 until m col <- 0 until m @@ -144,6 +144,65 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { else println(s"Verification failed with $invalidcnt errors.") } + private def testQKMulSingle[T: Numeric: ClassTag]( + dut: QKMulSingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val peCount = dut.peCount + val gemmType = dut.gemmType + + val Query = matInit[T](m, n) + val Key = matInit[T](m, n) + val expectedResults = mmul(Query, Key.transpose) + + printmat(Query) + printmat(Key) + printmat(expectedResults) + + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) + } + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + + val precision = 0.001f + var invalidcnt = 0 + + while (!dut.io.done.peekBoolean()) { + if (dut.io.curRowScores.valid.peekBoolean()) { + val currentRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"currentRow index: $currentRowIndex") + + for (i <- 0 until m) { + val outBigInt = dut.io.curRowScores.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(currentRowIndex.toInt)(i) + checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + } + dut.clock.step() + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + private def testAttnScores[T: Numeric: ClassTag]( dut: AttnScores )( @@ -160,7 +219,7 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val Query = mmul(inputToken, weightQ) val Key = mmul(inputToken, weightK) val expectedResults = mmul(Query, Key.transpose) - + print("Query:\n") printmat(Query) print("Key:\n") @@ -169,15 +228,15 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { printmat(expectedResults) if ( - dut.io.inputToken.ready.peekBoolean() && - dut.io.weightQ.ready.peekBoolean() && + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && dut.io.weightK.ready.peekBoolean() ) { println("inputToken, weightQ and weightK are ready") dut.io.inputToken.valid.poke(true.B) dut.io.weightQ.valid.poke(true.B) dut.io.weightK.valid.poke(true.B) - + for { row <- 0 until m col <- 0 until n @@ -200,7 +259,7 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { dut.io.scores.ready.poke(true.B) val precision = 0.001f var invalidcnt = 0 - + for { row <- 0 until m col <- 0 until m @@ -216,21 +275,202 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { if (invalidcnt == 0) println("Verification passed!") else println(s"Verification failed with $invalidcnt errors.") } - // "AttnScoresTotal " should "compute fxp matrix multiplication" in { + + private def testAttnScoresSingle[T: Numeric: ClassTag]( + dut: AttnScoresSingle + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val Query = mmul(inputToken, weightQ) + val Key = mmul(inputToken, weightK) + val expectedResults = mmul(Query, Key.transpose) + + print("Query:\n") + printmat(Query) + print("Key:\n") + printmat(Key) + print("expectedResults:\n") + printmat(expectedResults) + + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + val precision = 0.001f + var invalidcnt = 0 + + while (!dut.io.done.peekBoolean()) { + if (dut.io.curRowScores.valid.peekBoolean()) { + val currentRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"currentRow index: $currentRowIndex") + + for (i <- 0 until m) { + val outBigInt = dut.io.curRowScores.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(currentRowIndex.toInt)(i) + checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + } + dut.clock.step() + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + private def testAttnScoresSingleQueue[T: Numeric: ClassTag]( + dut: AttnScoresSingleQueue + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val Query = mmul(inputToken, weightQ) + val Key = mmul(inputToken, weightK) + val expectedResults = mmul(Query, Key.transpose) + + print("Query:\n") + printmat(Query) + print("Key:\n") + printmat(Key) + print("expectedResults:\n") + printmat(expectedResults) + dut.io.flush.poke(true.B) + dut.clock.step(1) + dut.io.flush.poke(false.B) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + + val precision = 0.001f + var invalidcnt = 0 + + while (!dut.io.done.peekBoolean()) { + if (dut.io.curRowScores.valid.peekBoolean()) { + val currentRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"currentRow index: $currentRowIndex") + + for (i <- 0 until m) { + val outBigInt = dut.io.curRowScores.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(currentRowIndex.toInt)(i) + checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + } + dut.clock.step() + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + + // "AttnScoresSingleQueue " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new AttnScoresTotal(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) + // test(new AttnScoresSingleQueue(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testAttnScoresTotal[Int](dut) + // testAttnScoresSingleQueue[Int](dut) // } // } + "AttnScoresSingleQueue " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new AttnScoresSingleQueue(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testAttnScoresSingleQueue[Int](dut) + } + } + + // "AttnScoresSingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new AttnScoresSingle(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScoresSingle[Int](dut) + // } + // } + + // "AttnScoresSingle " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new AttnScoresSingle(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScoresSingle[Float](dut) + // } + // } // "QKMul " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new QKMul(m = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) + // test(new QKMul(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => // testQKMul[Int](dut) // } // } + // "QKMulSingle " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMulSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMulSingle[Int](dut) + // } + // } // "AttnScores " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig // test(new AttnScores(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) @@ -263,13 +503,13 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { // } // } - "QKGen " should "compute fxp matrix multiplication" in { - implicit val config: DataWidthConfig = FxpConfig - test(new QKGen(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) - .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - testQKGen[Int](dut) - } - } + // "QKGen " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKGen(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKGen[Int](dut) + // } + // } // "QKGenWithReg " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala index 2700057..94b8f60 100644 --- a/src/test/scala/kernel/alu/GemmFMATest.scala +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -2,6 +2,7 @@ package kernel.alu import chisel3._ import chiseltest._ +import chisel3.util.DecoupledIO import org.scalatest.freespec.AnyFreeSpec import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.ParallelTestExecution @@ -12,6 +13,19 @@ import Utils._ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution { + // private trait FMADut { + // def k: Int + // def peCount: Int + // def gemmType: GEMMDataType.Type + // def io: Bundle { + // val reset: Bool + // val matrixA_row: DecoupledIO[Vec[UInt]] + // val matrixB_cols: DecoupledIO[Vec[Vec[UInt]]] + // val blockResult: DecoupledIO[Vec[UInt]] + // } + // def clock: Clock + // } + private def testMultiFMA[T: Numeric: ClassTag]( dut: MultiFMA )( @@ -24,7 +38,66 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val matrixA_row = matInit[T](1, k) val matrixB_cols = matInit[T](k, peCount) val expectedResults = mmul(matrixA_row, matrixB_cols) - + + printmat(matrixA_row) + printmat(matrixB_cols) + printmat(expectedResults) + + dut.io.reset.poke(true.B) + dut.clock.step(1) + dut.io.reset.poke(false.B) + + if (dut.io.matrixA_row.ready.peekBoolean() && dut.io.matrixB_cols.ready.peekBoolean()) { + println("matrixA_row and matrixB_cols are ready") + dut.io.matrixA_row.valid.poke(true.B) + dut.io.matrixB_cols.valid.poke(true.B) + + for { + i <- matrixA_row(0).indices + j <- 0 until peCount + } { + dut.io.matrixA_row.bits(i).poke(toBinaryBigInt(matrixA_row(0)(i)).U) + dut.io.matrixB_cols.bits(i)(j).poke(toBinaryBigInt(matrixB_cols(i)(j)).U) + } + } else { + dut.io.matrixA_row.valid.poke(false.B) + dut.io.matrixB_cols.valid.poke(false.B) + } + + while (!dut.io.blockResult.valid.peekBoolean()) { + dut.clock.step() + } + + dut.io.blockResult.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + + for (i <- 0 until peCount) { + val outBigInt = dut.io.blockResult.bits(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = expectedResults(0)(i) + checkResult(out, expected, 0, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + + if (invalidcnt == 0) println("Verification passed!") + else println(s"Verification failed with $invalidcnt errors.") + } + private def testMultiFMA_v2[T: Numeric: ClassTag]( + dut: MultiFMA_v2 + )( + implicit config: DataWidthConfig + ): Unit = { + val k = dut.k + val peCount = dut.peCount + val gemmType = dut.gemmType + + val matrixA_row = matInit[T](1, k) + val matrixB_cols = matInit[T](k, peCount) + val expectedResults = mmul(matrixA_row, matrixB_cols) + printmat(matrixA_row) printmat(matrixB_cols) printmat(expectedResults) @@ -37,7 +110,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe println("matrixA_row and matrixB_cols are ready") dut.io.matrixA_row.valid.poke(true.B) dut.io.matrixB_cols.valid.poke(true.B) - + for { i <- matrixA_row(0).indices j <- 0 until peCount @@ -86,7 +159,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val matrixA = matInit[T](m, k) val matrixB = matInit[T](k, n) val expectedResults = mmul(matrixA, matrixB) - + printmat(matrixA) printmat(matrixB) printmat(expectedResults) @@ -95,7 +168,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe println("matrixA and matrixB are ready") dut.io.matrixA.valid.poke(true.B) dut.io.matrixB.valid.poke(true.B) - + for { row <- 0 until m col <- 0 until n @@ -148,7 +221,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val matrixA = matInit[T](m, k) val matrixB = matInit[T](k, n) val expectedResults = mmul(matrixA, matrixB) - + printmat(matrixA) printmat(matrixB) printmat(expectedResults) @@ -157,7 +230,7 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe println("matrixA and matrixB are ready") dut.io.matrixA.valid.poke(true.B) dut.io.matrixB.valid.poke(true.B) - + for { row <- 0 until m col <- 0 until n @@ -175,16 +248,16 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe var invalidcnt = 0 while (!dut.io.done.peekBoolean()) { - if (dut.io.currentRow.valid.peekBoolean()) { - val currentRowIndex = dut.io.currentRow.bits.index.peekInt() - println(s"currentRow index: $currentRowIndex") - + if (dut.io.curRow.valid.peekBoolean()) { + val curRowIndex = dut.io.curRow.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + for (i <- 0 until n) { - val outBigInt = dut.io.currentRow.bits.value(i).peekInt() + val outBigInt = dut.io.curRow.bits.value(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(currentRowIndex.toInt)(i) - - checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + val expected = expectedResults(curRowIndex.toInt)(i) + + checkResult(out, expected, curRowIndex.toInt, i, precision) match { case Some(_) => invalidcnt += 1 case None => // right } @@ -210,16 +283,20 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val matrixA = matInit[T](m, k) val matrixB = matInit[T](k, n) val expectedResults = mmul(matrixA, matrixB) - + printmat(matrixA) printmat(matrixB) printmat(expectedResults) + dut.io.flush.poke(true.B) + dut.clock.step(1) + dut.io.flush.poke(false.B) + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { println("matrixA and matrixB are ready") dut.io.matrixA.valid.poke(true.B) dut.io.matrixB.valid.poke(true.B) - + for { row <- 0 until m col <- 0 until n @@ -233,21 +310,21 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe dut.io.matrixB.valid.poke(false.B) } - dut.io.currentRow.ready.poke(true.B) + dut.io.curRow.ready.poke(true.B) val precision = 0.001f var invalidcnt = 0 while (!dut.io.done.peekBoolean()) { - if (dut.io.currentRow.valid.peekBoolean()) { - val currentRowIndex = dut.io.currentRow.bits.index.peekInt() - println(s"currentRow index: $currentRowIndex") - + if (dut.io.curRow.valid.peekBoolean()) { + val curRowIndex = dut.io.curRow.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + for (i <- 0 until n) { - val outBigInt = dut.io.currentRow.bits.value(i).peekInt() + val outBigInt = dut.io.curRow.bits.value(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(currentRowIndex.toInt)(i) - - checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + val expected = expectedResults(curRowIndex.toInt)(i) + + checkResult(out, expected, curRowIndex.toInt, i, precision) match { case Some(_) => invalidcnt += 1 case None => // right } @@ -260,14 +337,13 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe else println(s"Verification failed with $invalidcnt errors.") } - - "GEMMSingleQueue " should "compute fxp matrix multiplication" in { - implicit val config: DataWidthConfig = FxpConfig - test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) - .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - testGEMMSingleQueue[Int](dut) - } - } + // "GEMMSingleQueue " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testGEMMSingleQueue[Int](dut) + // } + // } // "GEMMSingleQueue " should "compute fp32 matrix multiplication" in { // implicit val config: DataWidthConfig = Fp32Config @@ -324,4 +400,18 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe // testMultiFMA[Float](dut) // } // } + // "MultiFMA_v2 " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new MultiFMA_v2(k = 4, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMA_v2[Float](dut) + // } + // } + // "MultiFMA_v2 " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new MultiFMA_v2(k = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testMultiFMA_v2[Int](dut) + // } + // } } diff --git a/src/test/scala/kernel/alu/OutValueTest.scala b/src/test/scala/kernel/alu/OutValueTest.scala index a48a381..733ba02 100644 --- a/src/test/scala/kernel/alu/OutValueTest.scala +++ b/src/test/scala/kernel/alu/OutValueTest.scala @@ -106,31 +106,31 @@ class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelT } for (index <- 0 until m) { - if (dut.io.currentScores.ready.peekBoolean()) { - println(s"currentScores index: $index is ready") - dut.io.currentScores.valid.poke(true.B) + if (dut.io.curScores.ready.peekBoolean()) { + println(s"curScores index: $index is ready") + dut.io.curScores.valid.poke(true.B) for (i <- 0 until m) { - dut.io.currentScores.bits.value(i).poke(toBinaryBigInt(AttnWeights(index)(i)).U) + dut.io.curScores.bits.value(i).poke(toBinaryBigInt(AttnWeights(index)(i)).U) } } else { - dut.io.currentScores.valid.poke(false.B) + dut.io.curScores.valid.poke(false.B) } - dut.io.currentAttnOut.ready.poke(false.B) - while (!dut.io.currentAttnOut.valid.peekBoolean()) { + dut.io.curAttnOut.ready.poke(false.B) + while (!dut.io.curAttnOut.valid.peekBoolean()) { dut.clock.step() } - dut.io.currentAttnOut.ready.poke(true.B) - val currentRowIndex = dut.io.currentAttnOut.bits.index.peekInt() + dut.io.curAttnOut.ready.poke(true.B) + val curRowIndex = dut.io.curAttnOut.bits.index.peekInt() for (i <- 0 until n) { - val outBigInt = dut.io.currentAttnOut.bits.value(i).peekInt() + val outBigInt = dut.io.curAttnOut.bits.value(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(currentRowIndex.toInt)(i) + val expected = expectedResults(curRowIndex.toInt)(i) - checkResult(out, expected, currentRowIndex.toInt, i, precision) match { + checkResult(out, expected, curRowIndex.toInt, i, precision) match { case Some(_) => invalidcnt += 1 case None => // right } From b1e675656bb910e5f49a21d85c550be025359a4b Mon Sep 17 00:00:00 2001 From: pyfirstcsh Date: Tue, 14 Jan 2025 23:36:19 +0800 Subject: [PATCH 07/10] add fork test --- src/main/scala/kernel/alu/AttnScores.scala | 214 ++-- src/main/scala/kernel/alu/Gemm.scala | 2 +- src/main/scala/kernel/alu/GemmFMA.scala | 48 +- src/main/scala/kernel/alu/OutValue.scala | 21 +- .../scala/kernel/alu/AttnScoresTest.scala | 945 +++++++++++------- src/test/scala/kernel/alu/GemmFMATest.scala | 347 ++++--- src/test/scala/kernel/alu/OutValueTest.scala | 233 +++-- 7 files changed, 1048 insertions(+), 762 deletions(-) diff --git a/src/main/scala/kernel/alu/AttnScores.scala b/src/main/scala/kernel/alu/AttnScores.scala index b5dd89a..60df6d3 100644 --- a/src/main/scala/kernel/alu/AttnScores.scala +++ b/src/main/scala/kernel/alu/AttnScores.scala @@ -34,10 +34,13 @@ class QKGenWithReg( val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B io.Key.valid := false.B io.Key.bits := DontCare io.Query.valid := false.B @@ -69,7 +72,10 @@ class QKGenWithReg( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B + io.weightK.ready := false.B stateReg := state.gen } } @@ -85,7 +91,10 @@ class QKGenWithReg( is(state.done) { qGen.io.results.ready := false.B kGen.io.results.ready := false.B - readyReg := true.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B + // readyReg := true.B io.Query.valid := true.B io.Key.valid := true.B io.Query.bits := Qreg @@ -121,10 +130,13 @@ class QKGen( val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B io.Key.valid := false.B io.Key.bits := DontCare io.Query.valid := false.B @@ -153,7 +165,10 @@ class QKGen( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B + io.weightK.ready := false.B stateReg := state.gen } } @@ -167,7 +182,10 @@ class QKGen( is(state.done) { qGen.io.results.ready := false.B kGen.io.results.ready := false.B - readyReg := true.B + // readyReg := true.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B io.Query.valid := true.B io.Key.valid := true.B io.Query.bits := qGen.io.results.bits @@ -198,9 +216,11 @@ class QKMulWithReg( val dataValid = io.Query.valid && io.Key.valid - val readyReg = RegInit(true.B) - io.Query.ready := readyReg - io.Key.ready := readyReg + // val readyReg = RegInit(true.B) + // io.Query.ready := readyReg + // io.Key.ready := readyReg + io.Query.ready := true.B + io.Key.ready := true.B io.scores.valid := false.B io.scores.bits := DontCare @@ -221,7 +241,8 @@ class QKMulWithReg( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.Query.ready := false.B stateReg := state.mul } } @@ -234,7 +255,8 @@ class QKMulWithReg( } is(state.done) { QK_TMul.io.results.ready := false.B - readyReg := true.B + // readyReg := true.B + io.Query.ready := true.B io.scores.valid := true.B io.scores.bits := scoresReg stateReg := state.idle @@ -263,9 +285,11 @@ class QKMul( val dataValid = io.Query.valid && io.Key.valid - val readyReg = RegInit(true.B) - io.Query.ready := readyReg - io.Key.ready := readyReg + // val readyReg = RegInit(true.B) + // io.Query.ready := readyReg + // io.Key.ready := readyReg + io.Query.ready := true.B + io.Key.ready := true.B io.scores.valid := false.B io.scores.bits := DontCare @@ -285,7 +309,8 @@ class QKMul( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.Query.ready := false.B stateReg := state.mul } } @@ -297,7 +322,8 @@ class QKMul( } is(state.done) { QK_TMul.io.results.ready := false.B - readyReg := true.B + // readyReg := true.B + io.Query.ready := true.B io.scores.valid := true.B io.scores.bits := QK_TMul.io.results.bits stateReg := state.idle @@ -406,11 +432,14 @@ class AttnScores( val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B io.scores.valid := false.B io.scores.bits := DontCare @@ -445,7 +474,9 @@ class AttnScores( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B stateReg := state.gen } } @@ -467,7 +498,9 @@ class AttnScores( } is(state.done) { QKMul.io.scores.ready := false.B - readyReg := true.B + // readyReg := true.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B io.scores.valid := true.B // io.scores.bits := scoresReg io.scores.bits := QKMul.io.scores.bits @@ -502,10 +535,13 @@ class AttnScoresSingle( val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg + // val readyReg = RegInit(true.B) + // io.inputToken.ready := readyReg + // io.weightQ.ready := readyReg + // io.weightK.ready := readyReg + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B io.curRowScores.valid := false.B io.curRowScores.bits := DontCare io.done := false.B @@ -540,7 +576,10 @@ class AttnScoresSingle( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.inputToken.ready := false.B + io.weightQ.ready := false.B + io.weightK.ready := false.B stateReg := state.gen } } @@ -555,14 +594,8 @@ class AttnScoresSingle( QKGen.io.Query.ready := false.B QKGen.io.Key.ready := false.B QKMul.io.curRow.ready := true.B - // printf(p"QKGen.io.Query.bits: ${QKGen.io.Query.bits}\n") - // printf(p"QKGen.io.Key.bits: ${QKGen.io.Key.bits}\n") - // printf(p"QKMul.io.matrixA.bits: ${QKMul.io.matrixA.bits}\n") - // printf(p"QKMul.io.matrixB.bits: ${QKMul.io.matrixB.bits}\n") when(QKMul.io.curRow.valid) { curRowIndexReg := QKMul.io.curRow.bits - printf(p"QKMul.io.curRow.bits: ${QKMul.io.curRow.bits}\n") - printf(p"curRowIndexReg: ${curRowIndexReg}\n") stateReg := state.update } } @@ -570,9 +603,6 @@ class AttnScoresSingle( QKMul.io.curRow.ready := false.B io.curRowScores.valid := true.B io.curRowScores.bits := curRowIndexReg - printf(p"Update curRowIndexReg: ${curRowIndexReg}\n") - printf(p"Update io.curRowScores.bits: ${io.curRowScores.bits}\n") - printf(p"Update io.curRowScores.valid: ${io.curRowScores.valid}\n") when(QKMul.io.done) { stateReg := state.done }.otherwise { @@ -580,14 +610,17 @@ class AttnScoresSingle( } } is(state.done) { - readyReg := true.B + // readyReg := true.B + io.inputToken.ready := true.B + io.weightQ.ready := true.B + io.weightK.ready := true.B io.done := true.B stateReg := state.idle } } } -// AttnScoresSingleQueue: use QKGen to get Q and K, get scores by row +// AttnScoresSingleQueue: use Queue to store scores // input: inputToken: m * k // input: weightQ: k * n // input: weightK: k * n @@ -614,42 +647,9 @@ class AttnScoresSingleQueue( val done = Output(Bool()) }) - val dataValid = io.inputToken.valid && io.weightQ.valid && io.weightK.valid - - val readyReg = RegInit(true.B) - io.inputToken.ready := readyReg - io.weightQ.ready := readyReg - io.weightK.ready := readyReg - io.curRowScores.valid := false.B - io.curRowScores.bits := DontCare - io.done := false.B - - // val QKGen = Module(new QKGenWithReg(m, k, n, peCount, gemmType)) - val QKGen = Module(new QKGen(m, k, n, peCount, gemmType)) - - QKGen.io.inputToken.valid := io.inputToken.valid - QKGen.io.inputToken.bits := io.inputToken.bits - QKGen.io.weightQ.valid := io.weightQ.valid - QKGen.io.weightQ.bits := io.weightQ.bits - QKGen.io.weightK.valid := io.weightK.valid - QKGen.io.weightK.bits := io.weightK.bits - - QKGen.io.Query.ready := false.B - QKGen.io.Key.ready := false.B - - val curRowIndexReg = Reg(new curRowIndex(m, m)) - - val QKMul = Module(new GEMMFMASingle(m, n, m, peCount, gemmType)) - - QKMul.io.matrixA.valid := QKGen.io.Query.valid - QKMul.io.matrixA.bits := QKGen.io.Query.bits - QKMul.io.matrixB.valid := QKGen.io.Key.valid - QKMul.io.matrixB.bits := VecInit(QKGen.io.Key.bits.transpose.map(VecInit(_))) - QKMul.io.curRow.ready := false.B - val curBuffer = Module( new Queue( - new curRowIndex(m, n), + new curRowIndex(m, m), entries = bufferSize, pipe = true, flow = false, @@ -657,54 +657,18 @@ class AttnScoresSingleQueue( hasFlush = true ) ) + val doneReg = RegInit(false.B) + + val attnScores = Module(new AttnScoresSingle(m, k, n, peCount, gemmType)) + attnScores.io.inputToken <> io.inputToken + attnScores.io.weightQ <> io.weightQ + attnScores.io.weightK <> io.weightK + curBuffer.io.flush.get := io.flush - curBuffer.io.enq <> QKMul.io.curRow - io.curRowScores <> curBuffer.io.deq + curBuffer.io.enq <> attnScores.io.curRowScores + doneReg := attnScores.io.done - object state extends ChiselEnum { - val idle, gen, mul, update, done = Value - } - val stateReg = RegInit(state.idle) + io.curRowScores <> curBuffer.io.deq + io.done := doneReg - switch(stateReg) { - is(state.idle) { - when(dataValid) { - readyReg := false.B - stateReg := state.gen - } - } - is(state.gen) { - QKGen.io.Query.ready := true.B - QKGen.io.Key.ready := true.B - when(QKGen.io.Query.valid && QKGen.io.Key.valid) { - stateReg := state.mul - } - } - is(state.mul) { - QKGen.io.Query.ready := false.B - QKGen.io.Key.ready := false.B - QKMul.io.curRow.ready := true.B - // printf(p"QKGen.io.Query.bits: ${QKGen.io.Query.bits}\n") - // printf(p"QKGen.io.Key.bits: ${QKGen.io.Key.bits}\n") - when(QKMul.io.curRow.valid) { - // curRowIndexReg := QKMul.io.curRow.bits - stateReg := state.update - } - } - is(state.update) { - QKMul.io.curRow.ready := false.B - io.curRowScores.valid := true.B - // io.curRowScores.bits := curRowIndexReg - when(QKMul.io.done) { - stateReg := state.done - }.otherwise { - stateReg := state.mul - } - } - is(state.done) { - readyReg := true.B - io.done := true.B - stateReg := state.idle - } - } } diff --git a/src/main/scala/kernel/alu/Gemm.scala b/src/main/scala/kernel/alu/Gemm.scala index c7497fb..8225b70 100644 --- a/src/main/scala/kernel/alu/Gemm.scala +++ b/src/main/scala/kernel/alu/Gemm.scala @@ -8,7 +8,7 @@ import fputil.FPAdd import hardfloat._ trait GEMMAccuracyConfig { - val I: Int = 8 + val I: Int = 10 val F: Int = 0 } diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala index 4335258..caf0313 100644 --- a/src/main/scala/kernel/alu/GemmFMA.scala +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -180,9 +180,17 @@ class GEMMFMATotal( val dataValid = io.matrixA.valid && io.matrixB.valid - val readyReg = RegInit(true.B) - io.matrixA.ready := readyReg - io.matrixB.ready := readyReg + //TODO:ERR use readyReg and resValid case deadlock + // cases "Exception in thread "chiseltest_thread_2" java.lang.RuntimeException: Deadlock!" Error + // when test fork() and join() in chiseltest + + // val readyReg = RegInit(true.B) + // val resValid = RegInit(false.B) + // io.matrixA.ready := readyReg + // io.matrixB.ready := readyReg + // io.results.valid := resValid + io.matrixA.ready := true.B + io.matrixB.ready := true.B io.results.valid := false.B io.results.bits := DontCare @@ -214,7 +222,9 @@ class GEMMFMATotal( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.matrixA.ready := false.B + io.matrixB.ready := false.B stateReg := state.compute } } @@ -243,10 +253,12 @@ class GEMMFMATotal( }.otherwise { stateReg := state.compute } - } is(state.done) { - readyReg := true.B + // resValid := true.B + // readyReg := true.B + io.matrixA.ready := true.B + io.matrixB.ready := true.B io.results.valid := true.B io.results.bits := resultsReg stateReg := state.idle @@ -278,9 +290,11 @@ class GEMMFMASingle( }) val dataValid = io.matrixA.valid && io.matrixB.valid - val readyReg = RegInit(true.B) - io.matrixA.ready := readyReg - io.matrixB.ready := readyReg + // val readyReg = RegInit(true.B) + // io.matrixA.ready := readyReg + // io.matrixB.ready := readyReg + io.matrixA.ready := true.B + io.matrixB.ready := true.B io.curRow.valid := false.B io.curRow.bits := DontCare io.done := false.B @@ -313,7 +327,9 @@ class GEMMFMASingle( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.matrixA.ready := false.B + io.matrixB.ready := false.B stateReg := state.compute } } @@ -349,8 +365,10 @@ class GEMMFMASingle( is(state.done) { multiFMA.io.reset := true.B multiFMA.io.blockResult.ready := false.B + io.matrixA.ready := true.B + io.matrixB.ready := true.B io.done := true.B - readyReg := true.B + // readyReg := true.B stateReg := state.idle } } @@ -368,8 +386,8 @@ class GEMMSingleQueue( extends Module with DebugLog { val io = IO(new Bundle { - val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) - val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) + val matrixA = Flipped(Decoupled(Vec(m, Vec(k, UInt(config.inputWidth.W))))) + val matrixB = Flipped(Decoupled(Vec(k, Vec(n, UInt(config.inputWidth.W))))) val flush = Input(Bool()) val curRow = Decoupled(new curRowIndex(m, n)) val done = Output(Bool()) @@ -385,12 +403,14 @@ class GEMMSingleQueue( hasFlush = true ) ) + val doneReg = RegInit(false.B) val gemm = Module(new GEMMFMASingle(m, k, n, peCount, gemmType)) gemm.io.matrixA <> io.matrixA gemm.io.matrixB <> io.matrixB curBuffer.io.flush.get := io.flush curBuffer.io.enq <> gemm.io.curRow + doneReg := gemm.io.done io.curRow <> curBuffer.io.deq - io.done := gemm.io.done + io.done := doneReg } diff --git a/src/main/scala/kernel/alu/OutValue.scala b/src/main/scala/kernel/alu/OutValue.scala index df8e86c..d8d0bf8 100644 --- a/src/main/scala/kernel/alu/OutValue.scala +++ b/src/main/scala/kernel/alu/OutValue.scala @@ -28,9 +28,11 @@ class OutValue( val dataValid = io.Scores.valid && io.Value.valid - val readyReg = RegInit(true.B) - io.Scores.ready := readyReg - io.Value.ready := readyReg + // val readyReg = RegInit(true.B) + // io.Scores.ready := readyReg + // io.Value.ready := readyReg + io.Scores.ready := true.B + io.Value.ready := true.B io.AttnOut.valid := false.B io.AttnOut.bits := DontCare @@ -50,7 +52,9 @@ class OutValue( switch(stateReg) { is(state.idle) { when(dataValid) { - readyReg := false.B + // readyReg := false.B + io.Scores.ready := false.B + io.Value.ready := false.B stateReg := state.compute } } @@ -62,7 +66,9 @@ class OutValue( } is(state.done) { ValueMul.io.results.ready := false.B - readyReg := true.B + // readyReg := true.B + io.Scores.ready := true.B + io.Value.ready := true.B io.AttnOut.valid := true.B io.AttnOut.bits := ValueMul.io.results.bits stateReg := state.idle @@ -113,7 +119,7 @@ class OutValueSingle( multiFMA.io.matrixB_cols.valid := io.Value.valid multiFMA.io.matrixB_cols.bits := VecInit(Seq.tabulate(m) { j => VecInit(Seq.tabulate(peCount) { i => - ValueReg(j)(((colIndex.value << log2Ceil(peCount).U) + i.U)(log2Ceil(n)-1, 0)) + ValueReg(j)((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) }) }) //m * peCount size block of Value @@ -132,6 +138,7 @@ class OutValueSingle( is(state.idle) { when(dataValid) { io.Value.ready := false.B + io.curScores.ready := false.B stateReg := state.compute } } @@ -141,7 +148,7 @@ class OutValueSingle( multiFMA.io.blockResult.ready := true.B when(multiFMA.io.blockResult.valid) { for (i <- 0 until peCount) { - curRowReg(colIndex.value * peCount.U + i.U) := multiFMA.io.blockResult.bits(i) + curRowReg(colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0) := multiFMA.io.blockResult.bits(i) } stateReg := state.update } diff --git a/src/test/scala/kernel/alu/AttnScoresTest.scala b/src/test/scala/kernel/alu/AttnScoresTest.scala index 51c965c..83eee57 100644 --- a/src/test/scala/kernel/alu/AttnScoresTest.scala +++ b/src/test/scala/kernel/alu/AttnScoresTest.scala @@ -21,65 +21,174 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val k = dut.k val n = dut.n val gemmType = dut.gemmType - - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - val Key = mmul(inputToken, weightK) - - printmat(Query) - printmat(Key) - - if ( - dut.io.inputToken.ready.peekBoolean() && - dut.io.weightQ.ready.peekBoolean() && - dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until k - } { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) - } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) - } - - while (!(dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean())) { - dut.clock.step() + val caseNum = 10 + + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + (inputToken, weightQ, weightK, exQuery, exKey) } - dut.io.Key.ready.poke(true.B) - dut.io.Query.ready.poke(true.B) - - val precision = 0.001f - var invalidcnt = 0 - for { - row <- 0 until m - col <- 0 until n - } { - val outBigInt = dut.io.Query.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = Query(row)(col) - checkResult(out, expected, row, col, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exQuery, exKey) = testCases(resCnt) + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean()) { + dut.io.Key.ready.poke(true.B) + dut.io.Query.ready.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + val outQBigInt = dut.io.Query.bits(row)(col).peekInt() + val outKBigInt = dut.io.Key.bits(row)(col).peekInt() + val outQ = fromBinaryBigInt[T](outQBigInt) + val outK = fromBinaryBigInt[T](outKBigInt) + val expectedQ = exQuery(row)(col) + val expectedK = exKey(row)(col) + checkResult(outQ, expectedQ, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + checkResult(outK, expectedK, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.Key.ready.poke(false.B) + dut.io.Query.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + } + private def testQKGenWithReg[T: Numeric: ClassTag]( + dut: QKGenWithReg + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val k = dut.k + val n = dut.n + val gemmType = dut.gemmType + val caseNum = 10 + + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + (inputToken, weightQ, weightK, exQuery, exKey) } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exQuery, exKey) = testCases(resCnt) + if (dut.io.Key.valid.peekBoolean() && dut.io.Query.valid.peekBoolean()) { + dut.io.Key.ready.poke(true.B) + dut.io.Query.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until n + } { + val outQBigInt = dut.io.Query.bits(row)(col).peekInt() + val outKBigInt = dut.io.Key.bits(row)(col).peekInt() + val outQ = fromBinaryBigInt[T](outQBigInt) + val outK = fromBinaryBigInt[T](outKBigInt) + val expectedQ = exQuery(row)(col) + val expectedK = exKey(row)(col) + checkResult(outQ, expectedQ, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + checkResult(outK, expectedK, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.Key.ready.poke(false.B) + dut.io.Query.ready.poke(false.B) + } + dut.clock.step() + } + }.join() } private def testQKMul[T: Numeric: ClassTag]( @@ -91,57 +200,137 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val n = dut.n val gemmType = dut.gemmType - val Query = matInit[T](m, n) - val Key = matInit[T](m, n) - val expectedResults = mmul(Query, Key.transpose) - - println("Query:") - printmat(Query) - println("Key:") - printmat(Key) - println("expectedResults:") - printmat(expectedResults) - - if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { - println("Query and Key are ready") - dut.io.Query.valid.poke(true.B) - dut.io.Key.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - } { - dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) - dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) - } - } else { - dut.io.Query.valid.poke(false.B) - dut.io.Key.valid.poke(false.B) + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inQuery = matInit[T](m, n) + val inKey = matInit[T](m, n) + val exScores = mmul(inQuery, inKey.transpose) + (inQuery, inKey, exScores) } - while (!dut.io.scores.valid.peekBoolean()) { - dut.clock.step() + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inQuery, inKey, _) = testCases(cnt) + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("test case " + cnt + ": Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(inQuery(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(inKey(row)(col)).U) + } + cnt += 1 + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + dut.clock.step() + } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exScores) = testCases(resCnt) + if (dut.io.scores.valid.peekBoolean()) { + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.scores.ready.poke(false.B) + } + dut.clock.step() + } + }.join() + + } + + private def testQKMulWithReg[T: Numeric: ClassTag]( + dut: QKMulWithReg + )( + implicit config: DataWidthConfig + ): Unit = { + val m = dut.m + val n = dut.n + val gemmType = dut.gemmType + + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inQuery = matInit[T](m, n) + val inKey = matInit[T](m, n) + val exScores = mmul(inQuery, inKey.transpose) + (inQuery, inKey, exScores) } - dut.io.scores.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - - for { - row <- 0 until m - col <- 0 until m - } { - val outBigInt = dut.io.scores.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - checkResult(out, expected, row, col, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inQuery, inKey, _) = testCases(cnt) + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("test case " + cnt + ": Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(inQuery(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(inKey(row)(col)).U) + } + cnt += 1 + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) + } + dut.clock.step() } - } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exScores) = testCases(resCnt) + if (dut.io.scores.valid.peekBoolean()) { + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.scores.ready.poke(false.B) + } + dut.clock.step() + } + }.join() - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") } private def testQKMulSingle[T: Numeric: ClassTag]( @@ -154,53 +343,74 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val peCount = dut.peCount val gemmType = dut.gemmType - val Query = matInit[T](m, n) - val Key = matInit[T](m, n) - val expectedResults = mmul(Query, Key.transpose) - - printmat(Query) - printmat(Key) - printmat(expectedResults) - - if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { - println("Query and Key are ready") - dut.io.Query.valid.poke(true.B) - dut.io.Key.valid.poke(true.B) - for { - row <- 0 until m - col <- 0 until n - } { - dut.io.Query.bits(row)(col).poke(toBinaryBigInt(Query(row)(col)).U) - dut.io.Key.bits(row)(col).poke(toBinaryBigInt(Key(row)(col)).U) - } - } else { - dut.io.Query.valid.poke(false.B) - dut.io.Key.valid.poke(false.B) + val caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inQuery = matInit[T](m, n) + val inKey = matInit[T](m, n) + val exScores = mmul(inQuery, inKey.transpose) + (inQuery, inKey, exScores) } - val precision = 0.001f - var invalidcnt = 0 - - while (!dut.io.done.peekBoolean()) { - if (dut.io.curRowScores.valid.peekBoolean()) { - val currentRowIndex = dut.io.curRowScores.bits.index.peekInt() - println(s"currentRow index: $currentRowIndex") - - for (i <- 0 until m) { - val outBigInt = dut.io.curRowScores.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(currentRowIndex.toInt)(i) - checkResult(out, expected, currentRowIndex.toInt, i, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inQuery, inKey, _) = testCases(cnt) + if (dut.io.Query.ready.peekBoolean() && dut.io.Key.ready.peekBoolean()) { + println("test case " + cnt + ": Query and Key are ready") + dut.io.Query.valid.poke(true.B) + dut.io.Key.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Query.bits(row)(col).poke(toBinaryBigInt(inQuery(row)(col)).U) + dut.io.Key.bits(row)(col).poke(toBinaryBigInt(inKey(row)(col)).U) } + cnt += 1 + } else { + dut.io.Query.valid.poke(false.B) + dut.io.Key.valid.poke(false.B) } + dut.clock.step() } - dut.clock.step() - } + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exScores) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRowScores.valid.peekBoolean()) { + dut.io.curRowScores.ready.poke(true.B) + dut.io.curRowScores.ready.poke(true.B) + val curRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + dut.io.curRowScores.bits.index.expect(rowIdx.U) + for { + col <- 0 until m + } { + val outBigInt = dut.io.curRowScores.bits.value(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(rowIdx)(col) + checkResult(out, expected, rowIdx, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRowScores.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") } private def testAttnScores[T: Numeric: ClassTag]( @@ -213,67 +423,76 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val n = dut.n val gemmType = dut.gemmType - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - val Key = mmul(inputToken, weightK) - val expectedResults = mmul(Query, Key.transpose) - - print("Query:\n") - printmat(Query) - print("Key:\n") - printmat(Key) - print("expectedResults:\n") - printmat(expectedResults) - - if ( - dut.io.inputToken.ready.peekBoolean() && - dut.io.weightQ.ready.peekBoolean() && - dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until k - } { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) - } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) + var caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + val exScores = mmul(exQuery, exKey.transpose) + (inputToken, weightQ, weightK, exScores) } - while (!dut.io.scores.valid.peekBoolean()) { - dut.clock.step() - } - - dut.io.scores.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - - for { - row <- 0 until m - col <- 0 until m - } { - val outBigInt = dut.io.scores.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - checkResult(out, expected, row, col, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) + } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) + } + dut.clock.step() } - } - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exScores) = testCases(resCnt) + if (dut.io.scores.valid.peekBoolean()) { + dut.io.scores.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until m + } { + val outBigInt = dut.io.scores.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.scores.ready.poke(false.B) + } + dut.clock.step() + } + }.join() } private def testAttnScoresSingle[T: Numeric: ClassTag]( @@ -286,68 +505,82 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val n = dut.n val gemmType = dut.gemmType - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - val Key = mmul(inputToken, weightK) - val expectedResults = mmul(Query, Key.transpose) - - print("Query:\n") - printmat(Query) - print("Key:\n") - printmat(Key) - print("expectedResults:\n") - printmat(expectedResults) - - if ( - dut.io.inputToken.ready.peekBoolean() && - dut.io.weightQ.ready.peekBoolean() && - dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until k - } { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) - } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) + var caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + val exScores = mmul(exQuery, exKey.transpose) + (inputToken, weightQ, weightK, exScores) } - - val precision = 0.001f - var invalidcnt = 0 - - while (!dut.io.done.peekBoolean()) { - if (dut.io.curRowScores.valid.peekBoolean()) { - val currentRowIndex = dut.io.curRowScores.bits.index.peekInt() - println(s"currentRow index: $currentRowIndex") - - for (i <- 0 until m) { - val outBigInt = dut.io.curRowScores.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(currentRowIndex.toInt)(i) - checkResult(out, expected, currentRowIndex.toInt, i, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _) = testCases(cnt) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) } + dut.clock.step() } - dut.clock.step() - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exScores) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRowScores.valid.peekBoolean()) { + dut.io.curRowScores.ready.poke(true.B) + val curRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + dut.io.curRowScores.bits.index.expect(rowIdx.U) + for { + col <- 0 until m + } { + val outBigInt = dut.io.curRowScores.bits.value(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(rowIdx)(col) + checkResult(out, expected, rowIdx, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRowScores.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() } private def testAttnScoresSingleQueue[T: Numeric: ClassTag]( @@ -360,80 +593,87 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { val n = dut.n val gemmType = dut.gemmType - val inputToken = matInit[T](m, k) - val weightQ = matInit[T](k, n) - val weightK = matInit[T](k, n) - val Query = mmul(inputToken, weightQ) - val Key = mmul(inputToken, weightK) - val expectedResults = mmul(Query, Key.transpose) - - print("Query:\n") - printmat(Query) - print("Key:\n") - printmat(Key) - print("expectedResults:\n") - printmat(expectedResults) - dut.io.flush.poke(true.B) - dut.clock.step(1) - dut.io.flush.poke(false.B) - if ( - dut.io.inputToken.ready.peekBoolean() && - dut.io.weightQ.ready.peekBoolean() && - dut.io.weightK.ready.peekBoolean() - ) { - println("inputToken, weightQ and weightK are ready") - dut.io.inputToken.valid.poke(true.B) - dut.io.weightQ.valid.poke(true.B) - dut.io.weightK.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until k - } { - dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) - dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) - dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) - } - } else { - dut.io.inputToken.valid.poke(false.B) - dut.io.weightQ.valid.poke(false.B) - dut.io.weightK.valid.poke(false.B) + var caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inputToken = matInit[T](m, k) + val weightQ = matInit[T](k, n) + val weightK = matInit[T](k, n) + val exQuery = mmul(inputToken, weightQ) + val exKey = mmul(inputToken, weightK) + val exScores = mmul(exQuery, exKey.transpose) + (inputToken, weightQ, weightK, exScores) } - - val precision = 0.001f - var invalidcnt = 0 - - while (!dut.io.done.peekBoolean()) { - if (dut.io.curRowScores.valid.peekBoolean()) { - val currentRowIndex = dut.io.curRowScores.bits.index.peekInt() - println(s"currentRow index: $currentRowIndex") - - for (i <- 0 until m) { - val outBigInt = dut.io.curRowScores.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(currentRowIndex.toInt)(i) - checkResult(out, expected, currentRowIndex.toInt, i, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inputToken, weightQ, weightK, _) = testCases(cnt) + dut.io.flush.poke(true.B) + dut.clock.step() + dut.io.flush.poke(false.B) + if ( + dut.io.inputToken.ready.peekBoolean() && + dut.io.weightQ.ready.peekBoolean() && + dut.io.weightK.ready.peekBoolean() + ) { + println("test case " + cnt + ": inputToken, weightQ and weightK are ready") + dut.io.inputToken.valid.poke(true.B) + dut.io.weightQ.valid.poke(true.B) + dut.io.weightK.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.inputToken.bits(row)(i).poke(toBinaryBigInt(inputToken(row)(i)).U) + dut.io.weightQ.bits(i)(col).poke(toBinaryBigInt(weightQ(i)(col)).U) + dut.io.weightK.bits(i)(col).poke(toBinaryBigInt(weightK(i)(col)).U) } + cnt += 1 + } else { + dut.io.inputToken.valid.poke(false.B) + dut.io.weightQ.valid.poke(false.B) + dut.io.weightK.valid.poke(false.B) } + dut.clock.step() } - dut.clock.step() - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, _, exScores) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRowScores.valid.peekBoolean()) { + dut.io.curRowScores.ready.poke(true.B) + val curRowIndex = dut.io.curRowScores.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + dut.io.curRowScores.bits.index.expect(rowIdx.U) + for { + col <- 0 until m + } { + val outBigInt = dut.io.curRowScores.bits.value(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exScores(rowIdx)(col) + checkResult(out, expected, rowIdx, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRowScores.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() } - // "AttnScoresSingleQueue " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new AttnScoresSingleQueue(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testAttnScoresSingleQueue[Int](dut) - // } - // } - "AttnScoresSingleQueue " should "compute fxp matrix multiplication" in { implicit val config: DataWidthConfig = FxpConfig test(new AttnScoresSingleQueue(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) @@ -442,6 +682,14 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { } } + // "AttnScoresSingleQueue " should "compute fp32 matrix multiplication" in { + // implicit val config: DataWidthConfig = Fp32Config + // test(new AttnScoresSingleQueue(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fp32)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testAttnScoresSingleQueue[Float](dut) + // } + // } + // "AttnScoresSingle " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig // test(new AttnScoresSingle(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) @@ -457,27 +705,22 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { // testAttnScoresSingle[Float](dut) // } // } - // "QKMul " should "compute fxp matrix multiplication" in { + + // "QKMulSingle " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new QKMul(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // test(new QKMulSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testQKMul[Int](dut) + // testQKMulSingle[Int](dut) // } // } - // "QKMulSingle " should "compute fxp matrix multiplication" in { + + // "AttnScores " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new QKMulSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // test(new AttnScores(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testQKMulSingle[Int](dut) + // testAttnScores[Int](dut) // } // } -// "AttnScores " should "compute fxp matrix multiplication" in { -// implicit val config: DataWidthConfig = FxpConfig -// test(new AttnScores(m = 4, k = 4, n = 4, peCount = 4, gemmType = GEMMDataType.Fxp)) -// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => -// testAttnScores[Int](dut) -// } -// } // "AttnScores " should "compute fp32 matrix multiplication" in { // implicit val config: DataWidthConfig = Fp32Config @@ -495,13 +738,13 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { // } // } -// "QKMulWithReg " should "compute fxp matrix multiplication" in { -// implicit val config: DataWidthConfig = FxpConfig -// test(new QKMulWithReg(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) -// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => -// testQKMulWithReg[Int](dut) -// } -// } + // "QKMulWithReg " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKMulWithReg(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKMulWithReg[Int](dut) + // } + // } // "QKGen " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig @@ -511,11 +754,11 @@ class AttnScoresTest extends AnyFlatSpec with ChiselScalatestTester { // } // } -// "QKGenWithReg " should "compute fxp matrix multiplication" in { -// implicit val config: DataWidthConfig = FxpConfig -// test(new QKGenWithReg(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) -// .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => -// testQKGenWithReg[Int](dut) -// } -// } + // "QKGenWithReg " should "compute fxp matrix multiplication" in { + // implicit val config: DataWidthConfig = FxpConfig + // test(new QKGenWithReg(m = 8, k = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + // testQKGenWithReg[Int](dut) + // } + // } } diff --git a/src/test/scala/kernel/alu/GemmFMATest.scala b/src/test/scala/kernel/alu/GemmFMATest.scala index 94b8f60..c017c42 100644 --- a/src/test/scala/kernel/alu/GemmFMATest.scala +++ b/src/test/scala/kernel/alu/GemmFMATest.scala @@ -13,19 +13,6 @@ import Utils._ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTestExecution { - // private trait FMADut { - // def k: Int - // def peCount: Int - // def gemmType: GEMMDataType.Type - // def io: Bundle { - // val reset: Bool - // val matrixA_row: DecoupledIO[Vec[UInt]] - // val matrixB_cols: DecoupledIO[Vec[Vec[UInt]]] - // val blockResult: DecoupledIO[Vec[UInt]] - // } - // def clock: Clock - // } - private def testMultiFMA[T: Numeric: ClassTag]( dut: MultiFMA )( @@ -153,58 +140,69 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val m = dut.m val k = dut.k val n = dut.n - val peCount = dut.peCount val gemmType = dut.gemmType - val matrixA = matInit[T](m, k) - val matrixB = matInit[T](k, n) - val expectedResults = mmul(matrixA, matrixB) - - printmat(matrixA) - printmat(matrixB) - printmat(expectedResults) - - if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { - println("matrixA and matrixB are ready") - dut.io.matrixA.valid.poke(true.B) - dut.io.matrixB.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until k - } { - dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) - dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) - } - } else { - dut.io.matrixA.valid.poke(false.B) - dut.io.matrixB.valid.poke(false.B) - } + val caseNum = 10 - while (!dut.io.results.valid.peekBoolean()) { - dut.clock.step() + val testCases = Array.tabulate(caseNum) { i => + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val exResult = mmul(matrixA, matrixB) + (matrixA, matrixB, exResult) } - dut.io.results.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - - for { - row <- 0 until m - col <- 0 until n - } { - val outBigInt = dut.io.results.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - checkResult(out, expected, row, col, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (matrixA, matrixB, _) = testCases(cnt) + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println(s"case $cnt : matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) + } + cnt += 1 + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) + } + dut.clock.step() } - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exResult) = testCases(resCnt) + if (dut.io.results.valid.peekBoolean()) { + dut.io.results.ready.poke(true.B) + val precision = 0.001f + var invalidcnt = 0 + for { + row <- 0 until m + col <- 0 until n + } { + val outBigInt = dut.io.results.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exResult(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt Verification passed!") + else println(s"case $resCnt Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.results.ready.poke(false.B) + } + dut.clock.step() + } + }.join() } private def testGEMMFMASingle[T: Numeric: ClassTag]( @@ -218,56 +216,72 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val peCount = dut.peCount val gemmType = dut.gemmType - val matrixA = matInit[T](m, k) - val matrixB = matInit[T](k, n) - val expectedResults = mmul(matrixA, matrixB) - - printmat(matrixA) - printmat(matrixB) - printmat(expectedResults) + val caseNum = 10 - if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { - println("matrixA and matrixB are ready") - dut.io.matrixA.valid.poke(true.B) - dut.io.matrixB.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until k - } { - dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) - dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) - } - } else { - dut.io.matrixA.valid.poke(false.B) - dut.io.matrixB.valid.poke(false.B) + val testCases = Array.tabulate(caseNum) { i => + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val exResult = mmul(matrixA, matrixB) + (matrixA, matrixB, exResult) } - val precision = 0.001f - var invalidcnt = 0 - - while (!dut.io.done.peekBoolean()) { - if (dut.io.curRow.valid.peekBoolean()) { - val curRowIndex = dut.io.curRow.bits.index.peekInt() - println(s"curRow index: $curRowIndex") - - for (i <- 0 until n) { - val outBigInt = dut.io.curRow.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(curRowIndex.toInt)(i) - - checkResult(out, expected, curRowIndex.toInt, i, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (matrixA, matrixB, _) = testCases(cnt) + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println(s"case $cnt : matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) } + cnt += 1 + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) } + dut.clock.step() } - dut.clock.step() - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exResult) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRow.valid.peekBoolean()) { + dut.io.curRow.ready.poke(true.B) + val curRowIndex = dut.io.curRow.bits.index.peekInt() + dut.io.curRow.bits.index.expect(rowIdx.U) + println(s"curRow index: $curRowIndex") + for (i <- 0 until n) { + val outBigInt = dut.io.curRow.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exResult(curRowIndex.toInt)(i) + checkResult(out, expected, rowIdx, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRow.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() } private def testGEMMSingleQueue[T: Numeric: ClassTag]( @@ -280,70 +294,83 @@ class GEMMFMATest extends AnyFlatSpec with ChiselScalatestTester with ParallelTe val n = dut.n val gemmType = dut.gemmType - val matrixA = matInit[T](m, k) - val matrixB = matInit[T](k, n) - val expectedResults = mmul(matrixA, matrixB) + val caseNum = 10 - printmat(matrixA) - printmat(matrixB) - printmat(expectedResults) - - dut.io.flush.poke(true.B) - dut.clock.step(1) - dut.io.flush.poke(false.B) - - if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { - println("matrixA and matrixB are ready") - dut.io.matrixA.valid.poke(true.B) - dut.io.matrixB.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until k - } { - dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) - dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) - } - } else { - dut.io.matrixA.valid.poke(false.B) - dut.io.matrixB.valid.poke(false.B) + val testCases = Array.tabulate(caseNum) { i => + val matrixA = matInit[T](m, k) + val matrixB = matInit[T](k, n) + val exResult = mmul(matrixA, matrixB) + (matrixA, matrixB, exResult) } - - dut.io.curRow.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - - while (!dut.io.done.peekBoolean()) { - if (dut.io.curRow.valid.peekBoolean()) { - val curRowIndex = dut.io.curRow.bits.index.peekInt() - println(s"curRow index: $curRowIndex") - - for (i <- 0 until n) { - val outBigInt = dut.io.curRow.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(curRowIndex.toInt)(i) - - checkResult(out, expected, curRowIndex.toInt, i, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + dut.io.flush.poke(true.B) + dut.clock.step() + dut.io.flush.poke(false.B) + val (matrixA, matrixB, _) = testCases(cnt) + if (dut.io.matrixA.ready.peekBoolean() && dut.io.matrixB.ready.peekBoolean()) { + println(s"case $cnt : matrixA and matrixB are ready") + dut.io.matrixA.valid.poke(true.B) + dut.io.matrixB.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until k + } { + dut.io.matrixA.bits(row)(i).poke(toBinaryBigInt(matrixA(row)(i)).U) + dut.io.matrixB.bits(i)(col).poke(toBinaryBigInt(matrixB(i)(col)).U) } + cnt += 1 + } else { + dut.io.matrixA.valid.poke(false.B) + dut.io.matrixB.valid.poke(false.B) } + dut.clock.step() } - dut.clock.step() - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exResult) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curRow.valid.peekBoolean()) { + dut.io.curRow.ready.poke(true.B) + val curRowIndex = dut.io.curRow.bits.index.peekInt() + dut.io.curRow.bits.index.expect(rowIdx.U) + println(s"curRow index: $curRowIndex") + for (i <- 0 until n) { + val outBigInt = dut.io.curRow.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exResult(curRowIndex.toInt)(i) + checkResult(out, expected, rowIdx, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curRow.ready.poke(false.B) + } + dut.clock.step() + } + dut.io.done.expect(true.B) + resCnt += 1 + } + }.join() } - // "GEMMSingleQueue " should "compute fxp matrix multiplication" in { - // implicit val config: DataWidthConfig = FxpConfig - // test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) - // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => - // testGEMMSingleQueue[Int](dut) - // } - // } + "GEMMSingleQueue " should "compute fxp matrix multiplication" in { + implicit val config: DataWidthConfig = FxpConfig + test(new GEMMSingleQueue(m = 8, k = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) + .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => + testGEMMSingleQueue[Int](dut) + } + } // "GEMMSingleQueue " should "compute fp32 matrix multiplication" in { // implicit val config: DataWidthConfig = Fp32Config diff --git a/src/test/scala/kernel/alu/OutValueTest.scala b/src/test/scala/kernel/alu/OutValueTest.scala index 733ba02..139cae7 100644 --- a/src/test/scala/kernel/alu/OutValueTest.scala +++ b/src/test/scala/kernel/alu/OutValueTest.scala @@ -20,55 +20,66 @@ class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelT val n = dut.n val gemmType = dut.gemmType - val Scores = matInit[T](m, m) - val Value = matInit[T](m, n) - val expectedResults = mmul(Scores, Value) - - printmat(Scores) - printmat(Value) - printmat(expectedResults) - - if (dut.io.Scores.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { - println("Scores and Value are ready") - dut.io.Scores.valid.poke(true.B) - dut.io.Value.valid.poke(true.B) - - for { - row <- 0 until m - col <- 0 until n - i <- 0 until m - } { - dut.io.Scores.bits(row)(i).poke(toBinaryBigInt(Scores(row)(i)).U) - dut.io.Value.bits(i)(col).poke(toBinaryBigInt(Value(i)(col)).U) - } - } else { - dut.io.Scores.valid.poke(false.B) - dut.io.Value.valid.poke(false.B) - } - - while (!dut.io.AttnOut.valid.peekBoolean()) { - dut.clock.step() + var caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inScores = matInit[T](m, m) + val inValue = matInit[T](m, n) + val exAttnOut = mmul(inScores, inValue) + (inScores, inValue, exAttnOut) } - - dut.io.AttnOut.ready.poke(true.B) - val precision = 0.001f - var invalidcnt = 0 - - for { - row <- 0 until m - col <- 0 until n - } { - val outBigInt = dut.io.AttnOut.bits(row)(col).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(row)(col) - checkResult(out, expected, row, col, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inScores, inValue, _) = testCases(cnt) + if (dut.io.Scores.ready.peekBoolean() && dut.io.Value.ready.peekBoolean()) { + println("test case " + cnt + ": Scores and Value are ready") + dut.io.Scores.valid.poke(true.B) + dut.io.Value.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + i <- 0 until m + } { + dut.io.Scores.bits(row)(i).poke(toBinaryBigInt(inScores(row)(i)).U) + dut.io.Value.bits(i)(col).poke(toBinaryBigInt(inValue(i)(col)).U) + } + cnt += 1 + } else { + dut.io.Scores.valid.poke(false.B) + dut.io.Value.valid.poke(false.B) + } + dut.clock.step() } - } - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exAttnOut) = testCases(resCnt) + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.AttnOut.valid.peekBoolean()) { + dut.io.AttnOut.ready.poke(true.B) + val (_, _, exAttnOut) = testCases(resCnt) + for { + row <- 0 until m + col <- 0 until n + } { + val outBigInt = dut.io.AttnOut.bits(row)(col).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exAttnOut(row)(col) + checkResult(out, expected, row, col, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println("case " + resCnt + ": Verification passed!") + else println(s"case $resCnt : Verification failed with $invalidcnt errors.") + resCnt += 1 + } else { + dut.io.AttnOut.ready.poke(false.B) + } + dut.clock.step() + } + }.join() } private def testOutValueSingle[T: Numeric: ClassTag]( @@ -80,70 +91,84 @@ class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelT val n = dut.n val gemmType = dut.gemmType - val AttnWeights = matInit[T](m, m) - val Value = matInit[T](m, n) - val expectedResults = mmul(AttnWeights, Value) - - printmat(AttnWeights) - printmat(Value) - printmat(expectedResults) - - val precision = 0.001f - var invalidcnt = 0 - - if (dut.io.Value.ready.peekBoolean()) { - println("Value is ready") - dut.io.Value.valid.poke(true.B) - - for { - i <- 0 until m - j <- 0 until n - } { - dut.io.Value.bits(i)(j).poke(toBinaryBigInt(Value(i)(j)).U) - } - } else { - dut.io.Value.valid.poke(false.B) + var caseNum = 10 + val testCases = Array.tabulate(caseNum) { i => + val inScores = matInit[T](m, m) + val inValue = matInit[T](m, n) + val exAttnOut = mmul(inScores, inValue) + (inScores, inValue, exAttnOut) } - for (index <- 0 until m) { - if (dut.io.curScores.ready.peekBoolean()) { - println(s"curScores index: $index is ready") - dut.io.curScores.valid.poke(true.B) - - for (i <- 0 until m) { - dut.io.curScores.bits.value(i).poke(toBinaryBigInt(AttnWeights(index)(i)).U) + fork { + var cnt = 0 + while (cnt < caseNum) { + val (inScores, inValue, _) = testCases(cnt) + if (dut.io.Value.ready.peekBoolean()) { + println("test case " + cnt + ": Value is ready") + dut.io.Value.valid.poke(true.B) + for { + row <- 0 until m + col <- 0 until n + } { + dut.io.Value.bits(row)(col).poke(toBinaryBigInt(inValue(row)(col)).U) + } + cnt += 1 + } else { + dut.io.Value.valid.poke(false.B) + } + var rowIdx = 0 + while (rowIdx < m) { + if (dut.io.curScores.ready.peekBoolean()) { + println(s"curScores index: $rowIdx is ready") + dut.io.curScores.valid.poke(true.B) + for (i <- 0 until m) { + dut.io.curScores.bits.value(i).poke(toBinaryBigInt(inScores(rowIdx)(i)).U) + } + } else { + dut.io.curScores.valid.poke(false.B) + } + dut.clock.step() + rowIdx += 1 } - } else { - dut.io.curScores.valid.poke(false.B) - } - - dut.io.curAttnOut.ready.poke(false.B) - while (!dut.io.curAttnOut.valid.peekBoolean()) { - dut.clock.step() } - - dut.io.curAttnOut.ready.poke(true.B) - val curRowIndex = dut.io.curAttnOut.bits.index.peekInt() - - for (i <- 0 until n) { - val outBigInt = dut.io.curAttnOut.bits.value(i).peekInt() - val out = fromBinaryBigInt[T](outBigInt) - val expected = expectedResults(curRowIndex.toInt)(i) - - checkResult(out, expected, curRowIndex.toInt, i, precision) match { - case Some(_) => invalidcnt += 1 - case None => // right + }.fork { + var resCnt = 0 + while (resCnt < caseNum) { + val (_, _, exAttnOut) = testCases(resCnt) + var rowIdx = 0 + while (rowIdx < m) { + val precision = 0.001f + var invalidcnt = 0 + if (dut.io.curAttnOut.ready.peekBoolean()) { + dut.io.curAttnOut.ready.poke(true.B) + val curRowIndex = dut.io.curAttnOut.bits.index.peekInt() + println(s"curRow index: $curRowIndex") + dut.io.curAttnOut.bits.index.poke(rowIdx.U) + for (i <- 0 until n) { + val outBigInt = dut.io.curAttnOut.bits.value(i).peekInt() + val out = fromBinaryBigInt[T](outBigInt) + val expected = exAttnOut(rowIdx)(i) + val precision = 0.001f + checkResult(out, expected, rowIdx, i, precision) match { + case Some(_) => invalidcnt += 1 + case None => // right + } + } + if (invalidcnt == 0) println(s"case $resCnt : row $rowIdx Verification passed!") + else println(s"case $resCnt : row $rowIdx Verification failed with $invalidcnt errors.") + rowIdx += 1 + } else { + dut.io.curAttnOut.ready.poke(false.B) + } + dut.clock.step() } + dut.io.done.expect(true.B) + resCnt += 1 } - dut.clock.step() - } - - dut.io.done.expect(true.B) - - if (invalidcnt == 0) println("Verification passed!") - else println(s"Verification failed with $invalidcnt errors.") + }.join() } + //TODO:OutValueSingle Test ERROR "OutValueSingle " should "compute fxp matrix multiplication" in { implicit val config: DataWidthConfig = FxpConfig test(new OutValueSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) @@ -154,7 +179,7 @@ class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelT // "OutValue " should "compute fxp matrix multiplication" in { // implicit val config: DataWidthConfig = FxpConfig - // test(new OutValue(m = 8, n = 8, peCount = 4, gemmType = GEMMDataType.Fxp)) + // test(new OutValue(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) // .withAnnotations(Seq(VerilatorBackendAnnotation, WriteVcdAnnotation)) { dut => // testOutValue[Int](dut) // } From d1f806af3beb7fc7d210b04774bcfa14cdbe719b Mon Sep 17 00:00:00 2001 From: pyfirstcsh <8295488+cao-shuai-hu@user.noreply.gitee.com> Date: Sun, 19 Jan 2025 21:55:59 +0800 Subject: [PATCH 08/10] fix OutValue and all passed but some TODO should fix --- src/main/scala/kernel/alu/GemmFMA.scala | 7 ++++ src/main/scala/kernel/alu/OutValue.scala | 35 ++++++++++++++------ src/test/scala/kernel/alu/OutValueTest.scala | 16 +++++---- 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/main/scala/kernel/alu/GemmFMA.scala b/src/main/scala/kernel/alu/GemmFMA.scala index caf0313..d57cfe4 100644 --- a/src/main/scala/kernel/alu/GemmFMA.scala +++ b/src/main/scala/kernel/alu/GemmFMA.scala @@ -184,6 +184,13 @@ class GEMMFMATotal( // cases "Exception in thread "chiseltest_thread_2" java.lang.RuntimeException: Deadlock!" Error // when test fork() and join() in chiseltest + + //TODO: + // all the moudle use readyReg can case deadlock + // but use true.B ,false.B can't + // when fork() and join() in chiseltest + // maybe something wrong not find + // val readyReg = RegInit(true.B) // val resValid = RegInit(false.B) // io.matrixA.ready := readyReg diff --git a/src/main/scala/kernel/alu/OutValue.scala b/src/main/scala/kernel/alu/OutValue.scala index d8d0bf8..2c875ad 100644 --- a/src/main/scala/kernel/alu/OutValue.scala +++ b/src/main/scala/kernel/alu/OutValue.scala @@ -110,7 +110,7 @@ class OutValueSingle( val multiFMA = Module(new MultiFMA(m, peCount, gemmType)) - val rowIndex = Counter(m) + // val rowIndex = Counter(m) val colIndex = Counter(n / peCount) multiFMA.io.matrixA_row.valid := io.curScores.valid @@ -127,9 +127,10 @@ class OutValueSingle( multiFMA.io.blockResult.ready := false.B val curRowReg = Reg(Vec(n, UInt(config.outputWidth.W))) + val curRowIndex = RegInit(0.U) object state extends ChiselEnum { - val idle, compute, update, load, done = Value + val idle, compute, update, output, load, done = Value } val stateReg = RegInit(state.idle) @@ -144,40 +145,54 @@ class OutValueSingle( } is(state.compute) { io.curScores.ready := false.B + io.Value.ready := false.B + multiFMA.io.reset := false.B multiFMA.io.blockResult.ready := true.B + curRowIndex := io.curScores.bits.index when(multiFMA.io.blockResult.valid) { for (i <- 0 until peCount) { - curRowReg(colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0) := multiFMA.io.blockResult.bits(i) + curRowReg((colIndex.value * peCount.U + i.U)(log2Ceil(n) - 1, 0)) := multiFMA.io.blockResult.bits(i) } stateReg := state.update } } is(state.update) { + io.curScores.ready := false.B + io.Value.ready := false.B multiFMA.io.reset := true.B multiFMA.io.blockResult.ready := false.B - io.curAttnOut.valid := false.B + // io.curAttnOut.valid := false.B when(colIndex.inc()) { io.curAttnOut.valid := true.B - io.curAttnOut.bits.index := rowIndex.value + io.curAttnOut.bits.index := curRowIndex io.curAttnOut.bits.value := curRowReg - when(rowIndex.inc()) { - stateReg := state.done - }.otherwise { - stateReg := state.load - } + stateReg := state.output }.otherwise { stateReg := state.compute } } + is(state.output) { + io.curScores.ready := false.B + io.Value.ready := false.B + when(io.curScores.bits.index === (m - 1).U) { + // print(p"curScores.ready: ${io.curScores.ready}") + stateReg := state.done + }.otherwise { + stateReg := state.load + } + } is(state.load) { io.curScores.ready := true.B + io.Value.ready := false.B + io.curAttnOut.valid := false.B stateReg := state.compute } is(state.done) { io.done := true.B io.Value.ready := true.B io.curScores.ready := true.B + io.curAttnOut.valid := false.B stateReg := state.idle } } diff --git a/src/test/scala/kernel/alu/OutValueTest.scala b/src/test/scala/kernel/alu/OutValueTest.scala index 139cae7..f6e8481 100644 --- a/src/test/scala/kernel/alu/OutValueTest.scala +++ b/src/test/scala/kernel/alu/OutValueTest.scala @@ -119,16 +119,17 @@ class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelT var rowIdx = 0 while (rowIdx < m) { if (dut.io.curScores.ready.peekBoolean()) { - println(s"curScores index: $rowIdx is ready") + println(s"case ${cnt-1} curScores index: $rowIdx is ready") dut.io.curScores.valid.poke(true.B) for (i <- 0 until m) { dut.io.curScores.bits.value(i).poke(toBinaryBigInt(inScores(rowIdx)(i)).U) } + dut.io.curScores.bits.index.poke(rowIdx.U) + rowIdx += 1 } else { dut.io.curScores.valid.poke(false.B) } dut.clock.step() - rowIdx += 1 } } }.fork { @@ -139,11 +140,10 @@ class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelT while (rowIdx < m) { val precision = 0.001f var invalidcnt = 0 - if (dut.io.curAttnOut.ready.peekBoolean()) { + if (dut.io.curAttnOut.valid.peekBoolean()) { dut.io.curAttnOut.ready.poke(true.B) val curRowIndex = dut.io.curAttnOut.bits.index.peekInt() - println(s"curRow index: $curRowIndex") - dut.io.curAttnOut.bits.index.poke(rowIdx.U) + dut.io.curAttnOut.bits.index.expect(rowIdx.U) for (i <- 0 until n) { val outBigInt = dut.io.curAttnOut.bits.value(i).peekInt() val out = fromBinaryBigInt[T](outBigInt) @@ -162,13 +162,17 @@ class OutValueTest extends AnyFlatSpec with ChiselScalatestTester with ParallelT } dut.clock.step() } + dut.clock.step() dut.io.done.expect(true.B) + println(s"case $resCnt done\n") resCnt += 1 } }.join() } - //TODO:OutValueSingle Test ERROR + //TODO: Fix the warning in the OutValueSingle module + // warning: This module has an additional loop for the curScores input. + // The error might be caused by the valid signal of the input"curScores", but it hasn't been resolved yet. "OutValueSingle " should "compute fxp matrix multiplication" in { implicit val config: DataWidthConfig = FxpConfig test(new OutValueSingle(m = 8, n = 12, peCount = 4, gemmType = GEMMDataType.Fxp)) From 8ef89cf0d1afe4a274345a19041e5e30d4211e05 Mon Sep 17 00:00:00 2001 From: pyfirstcsh <8295488+cao-shuai-hu@user.noreply.gitee.com> Date: Sun, 16 Feb 2025 18:59:38 +0800 Subject: [PATCH 09/10] chage verilator version --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 28553de..86c9550 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: env: - verilator-version: v4.202 + verilator-version: v5.020 verilator-install-dir: verilator-install jobs: From 5c8b0e6176e354faf8fd8be3633f70a0f72291d6 Mon Sep 17 00:00:00 2001 From: pyfirstcsh <8295488+cao-shuai-hu@user.noreply.gitee.com> Date: Wed, 19 Feb 2025 20:48:00 +0800 Subject: [PATCH 10/10] change gcc version --- .github/workflows/test.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 86c9550..aed9351 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ on: workflow_dispatch: env: - verilator-version: v5.020 + verilator-version: v4.202 verilator-install-dir: verilator-install jobs: @@ -29,7 +29,8 @@ jobs: apps: sbt - name: Setup Dependencies run: | - sudo apt-get install ccache + sudo apt-get install ccache g++-11 # 安装 GCC 11 + sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 11 - name: Get Cached Verilator id: get-cached-verilator uses: actions/cache@v4