diff --git a/include/core/allocator.h b/include/core/allocator.h index 002601d..7a93aad 100644 --- a/include/core/allocator.h +++ b/include/core/allocator.h @@ -26,6 +26,11 @@ namespace infini { // =================================== 作业 =================================== // TODO:可能需要设计一个数据结构来存储free block,以便于管理和合并 // HINT: 可以使用一个 map 来存储 free block,key 为 block 的起始/结尾地址,value 为 block 的大小 + // Free block management: + // - freeBlocksByAddr: map from start address to size, for allocation search + // - freeBlocksByEnd: map from end address to size, for merging adjacent blocks + std::map freeBlocksByAddr; // key: start address, value: size + std::map freeBlocksByEnd; // key: end address, value: size // =================================== 作业 =================================== public: diff --git a/src/core/allocator.cc b/src/core/allocator.cc index ff593ae..e12eb47 100644 --- a/src/core/allocator.cc +++ b/src/core/allocator.cc @@ -31,6 +31,45 @@ namespace infini // =================================== 作业 =================================== // TODO: 设计一个算法来分配内存,返回起始地址偏移量 + // Use First Fit algorithm: find the first free block that is large enough + for (auto it = freeBlocksByAddr.begin(); it != freeBlocksByAddr.end(); ++it) + { + size_t addr = it->first; + size_t blockSize = it->second; + + if (blockSize >= size) + { + // Found a suitable block + // Remove this block from both maps + freeBlocksByAddr.erase(it); + freeBlocksByEnd.erase(addr + blockSize); + + // If the block is larger than needed, add the remaining part back + if (blockSize > size) + { + size_t newAddr = addr + size; + size_t newSize = blockSize - size; + freeBlocksByAddr[newAddr] = newSize; + freeBlocksByEnd[newAddr + newSize] = newSize; + } + + // Update memory usage statistics + used += size; + if (used > peak) + { + peak = used; + } + + return addr; + } + } + + // No suitable free block found, allocate at the end + size_t addr = peak; + used += size; + peak += size; + + return addr; // =================================== 作业 =================================== return 0; @@ -43,6 +82,75 @@ namespace infini // =================================== 作业 =================================== // TODO: 设计一个算法来回收内存 + // Update memory usage + used -= size; + + size_t blockStart = addr; + size_t blockEnd = addr + size; + size_t blockSize = size; + + // Special case: if freeing the block at the end, just reduce peak + if (blockEnd == peak) + { + // Check if we can merge with a previous free block at the end + auto prevIt = freeBlocksByEnd.find(blockStart); + if (prevIt != freeBlocksByEnd.end()) + { + // Merge with the previous block and reduce peak further + size_t prevSize = prevIt->second; + size_t prevStart = blockStart - prevSize; + + // Remove the previous block from both maps + freeBlocksByAddr.erase(prevStart); + freeBlocksByEnd.erase(blockStart); + + // Reduce peak to the start of the merged block + peak = prevStart; + } + else + { + // Just reduce peak + peak = blockStart; + } + return; + } + + // Try to merge with the previous adjacent free block + auto prevIt = freeBlocksByEnd.find(blockStart); + if (prevIt != freeBlocksByEnd.end()) + { + // Found a previous adjacent block + size_t prevSize = prevIt->second; + size_t prevStart = blockStart - prevSize; + + // Remove the previous block from both maps + freeBlocksByAddr.erase(prevStart); + freeBlocksByEnd.erase(blockStart); + + // Merge: extend the current block backwards + blockStart = prevStart; + blockSize += prevSize; + } + + // Try to merge with the next adjacent free block + auto nextIt = freeBlocksByAddr.find(blockEnd); + if (nextIt != freeBlocksByAddr.end()) + { + // Found a next adjacent block + size_t nextSize = nextIt->second; + + // Remove the next block from both maps + freeBlocksByAddr.erase(blockEnd); + freeBlocksByEnd.erase(blockEnd + nextSize); + + // Merge: extend the current block forwards + blockSize += nextSize; + blockEnd += nextSize; + } + + // Add the merged (or original) free block to both maps + freeBlocksByAddr[blockStart] = blockSize; + freeBlocksByEnd[blockEnd] = blockSize; // =================================== 作业 =================================== } diff --git a/src/core/graph.cc b/src/core/graph.cc index 3a90637..8446c0c 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -1,4 +1,6 @@ #include "core/graph.h" +#include "operators/matmul.h" +#include "operators/transpose.h" #include #include #include @@ -106,6 +108,196 @@ namespace infini // 1. 去除冗余的算子(例如,两个相邻的算子都是 transpose 算子,且做的是相反的操作,可以将其全部删除) // 2. 合并算子(例如,矩阵乘算子中含有属性transA、transB,如果其输入存在transpose,且对最后两个维度做交换,就可以将transpose融入到矩阵乘算子的属性中去) // =================================== 作业 =================================== + + // 辅助lambda:检查transpose是否只交换最后两个维度 + auto isSwapLastTwo = [](const vector &perm) -> bool { + int rank = perm.size(); + if (rank < 2) return false; + for (int i = 0; i < rank - 2; ++i) { + if (perm[i] != i) return false; + } + return perm[rank - 2] == rank - 1 && perm[rank - 1] == rank - 2; + }; + + // 辅助lambda:检查两个permute是否互为逆操作 + auto isInversePermute = [](const vector &perm1, const vector &perm2) -> bool { + if (perm1.size() != perm2.size()) return false; + for (size_t i = 0; i < perm1.size(); ++i) { + if (perm2[perm1[i]] != (int)i) return false; + } + return true; + }; + + // 辅助lambda:在ops中查找operator + auto findInOps = [this](const Operator &target) -> bool { + return std::find(ops.begin(), ops.end(), target) != ops.end(); + }; + + // 辅助lambda:从所有其他op中清理对指定op的引用 + auto cleanupOpReferences = [this](const Operator &opToRemove) { + for (auto &otherOp : ops) { + if (otherOp != opToRemove) { + otherOp->removePredecessors(opToRemove); + otherOp->removeSuccessors(opToRemove); + } + } + }; + + bool changed = true; + while (changed) + { + changed = false; + + // 规则1: 去除冗余的transpose (两个连续transpose且互为逆操作) + for (size_t idx = 0; idx < ops.size() && !changed; ++idx) + { + auto op = ops[idx]; // 复制,不是引用 + if (op->getOpType() != OpType::Transpose) + continue; + + auto transposeOp = as(op); + auto middleTensor = op->getInputs(0); + auto source = middleTensor->getSource(); + + if (!source || !findInOps(source) || source->getOpType() != OpType::Transpose) + continue; + + auto prevTranspose = as(source); + auto perm1 = prevTranspose->getPermute(); + auto perm2 = transposeOp->getPermute(); + + if (!isInversePermute(perm1, perm2)) + continue; + + // 两个transpose互为逆操作,可以消除 + auto originalInput = prevTranspose->getInputs(0); + auto finalOutput = transposeOp->getOutput(); + + // 找出使用finalOutput的所有算子(直接遍历ops) + for (auto &succOp : ops) + { + if (succOp == op || succOp == source) continue; + + bool needsUpdate = false; + for (auto &in : succOp->getInputs()) + { + if (in == finalOutput) + { + needsUpdate = true; + break; + } + } + + if (needsUpdate) + { + succOp->replaceInput(finalOutput, originalInput); + originalInput.get()->addTarget(succOp); + } + } + + // 清理连接 + originalInput.get()->removeTarget(source); + + // 清理其他op对这两个op的引用 + cleanupOpReferences(source); + cleanupOpReferences(op); + + // 删除 + removeTensor(middleTensor); + removeTensor(finalOutput); + removeOperator(source); + removeOperator(op); + + changed = true; + } + + if (changed) + continue; + + // 规则2: 将transpose融入matmul + for (size_t idx = 0; idx < ops.size() && !changed; ++idx) + { + auto op = ops[idx]; + if (op->getOpType() != OpType::MatMul) + continue; + + auto matmulOp = as(op); + + // 检查输入A + auto inputA = matmulOp->getInputs(0); + auto sourceA = inputA->getSource(); + if (sourceA && findInOps(sourceA) && sourceA->getOpType() == OpType::Transpose) + { + auto transposeA = as(sourceA); + auto permA = transposeA->getPermute(); + + // 计算使用inputA的op数量 + int useCount = 0; + for (auto &checkOp : ops) + { + for (auto &in : checkOp->getInputs()) + { + if (in == inputA) useCount++; + } + } + + if (isSwapLastTwo(permA) && useCount == 1) + { + auto originalInput = transposeA->getInputs(0); + matmulOp->replaceInput(inputA, originalInput); + originalInput.get()->addTarget(matmulOp); + originalInput.get()->removeTarget(sourceA); + matmulOp->setTransA(!matmulOp->getTransA()); + + // 清理其他op对transposeA的引用 + cleanupOpReferences(sourceA); + + removeTensor(inputA); + removeOperator(sourceA); + + changed = true; + continue; + } + } + + // 检查输入B + auto inputB = matmulOp->getInputs(1); + auto sourceB = inputB->getSource(); + if (sourceB && findInOps(sourceB) && sourceB->getOpType() == OpType::Transpose) + { + auto transposeB = as(sourceB); + auto permB = transposeB->getPermute(); + + // 计算使用inputB的op数量 + int useCount = 0; + for (auto &checkOp : ops) + { + for (auto &in : checkOp->getInputs()) + { + if (in == inputB) useCount++; + } + } + + if (isSwapLastTwo(permB) && useCount == 1) + { + auto originalInput = transposeB->getInputs(0); + matmulOp->replaceInput(inputB, originalInput); + originalInput.get()->addTarget(matmulOp); + originalInput.get()->removeTarget(sourceB); + matmulOp->setTransB(!matmulOp->getTransB()); + + // 清理其他op对transposeB的引用 + cleanupOpReferences(sourceB); + + removeTensor(inputB); + removeOperator(sourceB); + + changed = true; + continue; + } + } + } + } } Tensor GraphObj::getTensor(int fuid) const @@ -151,6 +343,95 @@ namespace infini // =================================== 作业 =================================== // TODO:利用 allocator 给计算图分配内存 // HINT: 获取分配好的内存指针后,可以调用 tensor 的 setDataBlob 函数给 tensor 绑定内存 + + // Track tensor lifetime: when a tensor is last used (by which operator index) + std::unordered_map tensorLastUse; + std::unordered_map tensorAddress; + + // Initialize: all tensors are used at least once initially (for outputs without targets) + for (auto &tensor : tensors) + { + tensorLastUse[tensor.get()] = -1; + } + + // Allocate memory for input tensors (tensors without source) first + for (auto &tensor : tensors) + { + if (!tensor->getSource()) + { + size_t offset = allocator.alloc(tensor->getBytes()); + tensorAddress[tensor.get()] = offset; + } + } + + // Calculate last use for each tensor based on the operators + for (size_t i = 0; i < ops.size(); ++i) + { + auto &op = ops[i]; + + // Check inputs - update their last use time + for (auto &input : op->getInputs()) + { + if (input) + { + tensorLastUse[input.get()] = i; + } + } + } + + // For output tensors that have no targets, they should live until the end + for (auto &tensor : tensors) + { + if (tensor->getTargets().size() == 0 && tensor->getSource()) + { + // This is a graph output, it should live until the end + tensorLastUse[tensor.get()] = ops.size(); + } + } + + // Process each operator in topological order + for (size_t i = 0; i < ops.size(); ++i) + { + auto &op = ops[i]; + + // Allocate memory for outputs + for (auto &output : op->getOutputs()) + { + if (output && tensorAddress.find(output.get()) == tensorAddress.end()) + { + size_t offset = allocator.alloc(output->getBytes()); + tensorAddress[output.get()] = offset; + } + } + + // Free inputs that are no longer needed after this operator + for (auto &input : op->getInputs()) + { + if (input && tensorLastUse[input.get()] == (int)i) + { + // This is the last use of this tensor + if (tensorAddress.find(input.get()) != tensorAddress.end()) + { + allocator.free(tensorAddress[input.get()], input->getBytes()); + } + } + } + } + + // Get the actual memory pointer from allocator + void *basePtr = allocator.getPtr(); + + // Bind memory to each tensor + for (auto &tensor : tensors) + { + if (tensorAddress.find(tensor.get()) != tensorAddress.end()) + { + size_t offset = tensorAddress[tensor.get()]; + void *tensorPtr = reinterpret_cast(basePtr) + offset; + auto blob = make_ref(runtime, tensorPtr); + tensor->setDataBlob(blob); + } + } // =================================== 作业 =================================== allocator.info(); diff --git a/src/operators/concat.cc b/src/operators/concat.cc index d196330..ed08abf 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -16,7 +16,30 @@ optional> ConcatObj::inferShape(const TensorVec &inputs) { // =================================== 作业 =================================== // TODO:修改 dims,返回正确的 concat 后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Concat.html#concat-13 - // =================================== 作业 =================================== + + // All inputs should have the same shape except for the dimension being concatenated + // Sum up the sizes along the concatenation dimension + int concatDimSize = 0; + for (size_t i = 0; i < inputs.size(); ++i) { + auto inputDims = inputs[i]->getDims(); + + // Verify that all other dimensions match + if (inputDims.size() != dims.size()) { + return std::nullopt; + } + + for (size_t j = 0; j < dims.size(); ++j) { + if ((int)j != dim && inputDims[j] != dims[j]) { + return std::nullopt; + } + } + + // Accumulate the size of the concatenation dimension + concatDimSize += inputDims[dim]; + } + + // Update the output shape + dims[dim] = concatDimSize; return {{dims}}; } diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 7a16ca2..5ceb411 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -21,13 +21,61 @@ namespace infini return os.str(); } + + optional> MatmulObj::inferShape(const TensorVec &inputs) { - // =================================== 作业 =================================== - // TODO:返回经过 matmul 操作后的 shape - // REF: https://github.com/onnx/onnx/blob/main/docs/Operators.md#gemm - // =================================== 作业 =================================== - return std::nullopt; + // 矩阵乘法: C = A * B + // A 和 B 的最后两个维度进行矩阵乘法,前面的维度是批量维度 + const auto &A = inputs[0], &B = inputs[1]; + auto dimsA = A->getDims(); + auto dimsB = B->getDims(); + + int rankA = A->getRank(); + int rankB = B->getRank(); + + // 获取矩阵维度 + // transA/transB 只影响最后两个维度 + int rowA = transA ? dimsA[rankA - 1] : dimsA[rankA - 2]; + int colA = transA ? dimsA[rankA - 2] : dimsA[rankA - 1]; + int rowB = transB ? dimsB[rankB - 1] : dimsB[rankB - 2]; + int colB = transB ? dimsB[rankB - 2] : dimsB[rankB - 1]; + + // 保存 m, n, k 用于后续计算 + m = rowA; + k = colA; + n = colB; + + // 检查矩阵乘法维度匹配 + IT_ASSERT(colA == rowB, "Matrix dimensions mismatch for matmul"); + + // 计算批量维度(支持广播,右对齐) + int maxRank = std::max(rankA, rankB); + Shape outputShape(maxRank); + + // 批量维度广播(右对齐) + for (int i = 0; i < maxRank - 2; ++i) + { + // 右对齐:从右边对齐索引 + int idxA = i - (maxRank - rankA); + int idxB = i - (maxRank - rankB); + int dimA = (idxA >= 0 && idxA < rankA - 2) ? dimsA[idxA] : 1; + int dimB = (idxB >= 0 && idxB < rankB - 2) ? dimsB[idxB] : 1; + + // 广播规则检查:两个维度必须相等,或其中一个为1 + IT_ASSERT(dimA == dimB || dimA == 1 || dimB == 1, + "Batch dimensions cannot be broadcast"); + + outputShape[i] = std::max(dimA, dimB); + } + + // 设置输出的矩阵维度 + outputShape[maxRank - 2] = m; + outputShape[maxRank - 1] = n; + + return {{outputShape}}; } + + } // namespace infini \ No newline at end of file diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc index faab2b6..1017eb3 100644 --- a/src/operators/transpose.cc +++ b/src/operators/transpose.cc @@ -32,9 +32,15 @@ namespace infini // =================================== 作业 =================================== // TODO:修改 output_dim,返回正确的 transpose 后的 shape // REF: https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-21 + + // Apply the permutation to get the output shape + for (int i = 0; i < rank; ++i) + { + output_dim[i] = input_dim[transposePermute[i]]; + } // =================================== 作业 =================================== - return std::nullopt; + return {{output_dim}}; } std::string TransposeObj::toString() const diff --git a/src/operators/unary.cc b/src/operators/unary.cc index 3daad36..ac6d13d 100644 --- a/src/operators/unary.cc +++ b/src/operators/unary.cc @@ -35,11 +35,9 @@ namespace infini optional> ClipObj::inferShape(const TensorVec &inputs) { - // =================================== 作业 =================================== - // TODO:返回经过 clip 操作后的 shape - // REF: https://onnx.ai/onnx/operators/onnx__Clip.html#clip-13 - // =================================== 作业 =================================== - return std::nullopt; + // Clip 操作不改变 tensor 的 shape,只是将值限制在 [min, max] 范围内 + const auto A = inputs[0]; + return {{A->getDims()}}; } std::string ClipObj::toString() const @@ -61,21 +59,15 @@ namespace infini vector CastObj::inferDataType(const TensorVec &inputs) const { - // =================================== 作业 =================================== - // TODO:返回经过 cast 操作后, 输出 tensor 的数目和数据类型 - // REF_FILE: src/core/operator.cc - // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 - // =================================== 作业 =================================== - return {}; + // Cast 操作输出一个 tensor,数据类型由 castType 决定 + return {getOutputDataType()}; } optional> CastObj::inferShape(const TensorVec &inputs) { - // =================================== 作业 =================================== - // TODO:返回经过 cast 操作后的 shape - // REF: https://onnx.ai/onnx/operators/onnx__Cast.html#cast-21 - // =================================== 作业 =================================== - return std::nullopt; + // Cast 操作不改变 tensor 的 shape,只改变数据类型 + const auto input = inputs[0]; + return {{input->getDims()}}; } std::string CastObj::toString() const diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc index edbd2c8..a4cc27e 100644 --- a/src/utils/operator_utils.cc +++ b/src/utils/operator_utils.cc @@ -8,9 +8,47 @@ Shape infer_broadcast(const Shape &A, const Shape &B) { // =================================== 作业 =================================== // TODO:对 A 和 B 进行双向广播,返回广播后的形状。 // REF: https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md + + // Broadcasting rules: + // 1. If two shapes have different ranks, prepend 1s to the shorter one + // 2. For each dimension, the output dimension is max(dim_A, dim_B) + // 3. Dimensions are compatible if they are equal or one of them is 1 + + size_t rankA = A.size(); + size_t rankB = B.size(); + size_t maxRank = std::max(rankA, rankB); + + Shape result(maxRank); + + // Iterate from the trailing dimensions + for (size_t i = 0; i < maxRank; ++i) { + int dimA = 1, dimB = 1; + + // Get dimension from A (if exists) + if (i < rankA) { + dimA = A[rankA - 1 - i]; + } + + // Get dimension from B (if exists) + if (i < rankB) { + dimB = B[rankB - 1 - i]; + } + + // Check compatibility and compute output dimension + if (dimA == dimB) { + result[maxRank - 1 - i] = dimA; + } else if (dimA == 1) { + result[maxRank - 1 - i] = dimB; + } else if (dimB == 1) { + result[maxRank - 1 - i] = dimA; + } else { + // Incompatible dimensions + IT_ASSERT(false, "Incompatible broadcast dimensions"); + } + } // =================================== 作业 =================================== - return {}; + return result; } int get_real_axis(const int &axis, const int &rank) {