1. 思路小结
要优化你提供的稀疏矩阵乘法代码,我们可以引入CSR(压缩稀疏行)格式来避免遍历零元素,从而提高效率。CSR格式通过仅存储非零元素以及它们的行和列索引,可以有效减少稀疏矩阵计算时的时间复杂度。下面是对代码的优化版本,采用CSR格式进行稀疏矩阵的乘法:
优化步骤:
将稀疏矩阵转换为CSR格式,存储非零元素的位置和对应的值。
在矩阵乘法过程中,仅对非零元素进行计算,从而跳过零值。
对每一行的非零元素,在相应的列上执行乘法操作。
1.1 优化思路
进行的是两个稀疏矩阵的乘法。稀疏矩阵通常具有大量的零元素,因此直接使用常规矩阵乘法会导致大量的无效计算。为了提高效率,常用的优化方法是只对非零元素进行计算,而跳过零值。为此,我们采用**CSR(压缩稀疏行,Compressed Sparse Row)**格式进行稀疏矩阵存储和乘法计算。
1.1.1 核心步骤如下:
-
矩阵的稀疏表示:
- 原矩阵A和B可能有大量的零元素,因此我们采用CSR格式来存储这些矩阵。
- CSR格式由以下三个部分组成:
values[]
: 存储所有非零元素的值。colIndex[]
: 存储每个非零元素所在的列索引。rowPtr[]
: 记录每行的非零元素在values[]
中的起始位置。
-
矩阵的稀疏乘法:
- 对于矩阵A的每一行,我们找到其所有非零元素的位置及其值。
- 对于每一个非零元素,我们在矩阵B的相应列中查找与之匹配的非零元素。
- 最后将这些匹配的非零元素相乘,并累加到结果矩阵的对应位置。
-
优化:
- 通过CSR格式,避免了遍历和处理零元素,从而减少了不必要的计算。
- 我们直接对非零元素进行乘法运算,结果累积到结果矩阵C的对应位置。
1.2 算法复杂度分析
1.2.1 常规矩阵乘法的复杂度:
对于两个大小分别为 m x n
和 n x p
的矩阵,常规的矩阵乘法复杂度为O(m * n * p)。因为对于每一个 m x p
的结果元素,我们需要计算 n 次乘法操作。
1.2.2 稀疏矩阵乘法的复杂度:
由于稀疏矩阵大部分元素为零,我们只需要处理非零元素。假设矩阵A和矩阵B的非零元素分别为 nnzA
和 nnzB
,稀疏矩阵乘法的复杂度可以近似表示为:
- 对于每个非零元素
A[i][k]
,我们只需遍历矩阵B的第k列的非零元素进行乘法。因此稀疏矩阵乘法的复杂度大约为 O(nnzA * nnzB),其中nnzA
和nnzB
是矩阵A和矩阵B的非零元素数量。
这相比于常规矩阵乘法的复杂度有了显著的提升,尤其是当矩阵非常稀疏时(即大部分元素为0),非零元素的数量远小于矩阵的总大小。
1.2.3 空间复杂度:
使用CSR格式的空间复杂度为:
O(nnz)
:用于存储所有非零元素及其列索引。O(m)
:用于存储每一行的起始位置。
总体空间复杂度为 O(nnz + m),其中nnz
是矩阵的非零元素数量,m
是矩阵的行数。
1.2.4 总结
通过使用CSR格式存储稀疏矩阵,我们能够有效避免对零元素的计算,显著提升了稀疏矩阵乘法的计算效率。时间复杂度从常规的O(m * n * p)降低到接近于非零元素的数量 O(nnzA * nnzB)
,特别适合处理大规模稀疏矩阵的场景。
2. 优化后代码及其复杂度为
代码解析:
toCSR 函数:将普通的二维稀疏矩阵转换为CSR格式。values数组存储非零元素,colIndex存储每个非零元素的列索引,rowPtr则记录每行的非零元素在 values 数组中的起始位置。
multiplySparseMatricesCSR 函数:使用CSR格式进行矩阵乘法。通过 rowPtr 和 colIndex 来快速定位非零元素,避免了对零值的无效计算。
优化效果:
通过CSR格式存储非零元素,并跳过零元素的乘法操作,能够显著减少计算时间。
避免遍历零值,提高了计算效率,尤其在大规模稀疏矩阵的场景下。
#include <iostream>
#include <vector>using namespace std;// CSR格式的稀疏矩阵
struct CSRMatrix {vector<int> values; // 存储非零元素的值vector<int> colIndex; // 存储非零元素的列索引vector<int> rowPtr; // 每一行的开始位置
};// 将稀疏矩阵转换为CSR格式
CSRMatrix toCSR(const vector<vector<int>>& matrix) {CSRMatrix csr;int row = matrix.size();int col = matrix[0].size();csr.rowPtr.push_back(0); // 第一行的开始位置是0// 遍历矩阵,收集非零元素的信息for (int i = 0; i < row; i++) {for (int j = 0; j < col; j++) {if (matrix[i][j] != 0) {csr.values.push_back(matrix[i][j]);csr.colIndex.push_back(j);}}csr.rowPtr.push_back(csr.values.size()); // 记录下一行的开始位置}return csr;
}// 使用CSR格式进行稀疏矩阵乘法
vector<vector<int>> multiplySparseMatricesCSR(const CSRMatrix& A, const CSRMatrix& B, int colB) {int rowA = A.rowPtr.size() - 1;vector<vector<int>> C(rowA, vector<int>(colB, 0)); // 初始化结果矩阵// 遍历A的每一行for (int i = 0; i < rowA; i++) {// A的第i行的非零元素从A.rowPtr[i]到A.rowPtr[i+1]-1for (int aPos = A.rowPtr[i]; aPos < A.rowPtr[i+1]; aPos++) {int colA = A.colIndex[aPos]; // 该非零元素所在的列int aValue = A.values[aPos]; // 非零元素的值// 对应B的第colA行for (int j = B.rowPtr[colA]; j < B.rowPtr[colA+1]; j++) {int colBIndex = B.colIndex[j];int bValue = B.values[j];C[i][colBIndex] += aValue * bValue;}}}return C;
}int main() {// 定义稀疏矩阵Avector<vector<int>> A = {{1, 0, 0},{-1, 0, 3}};// 定义稀疏矩阵Bvector<vector<int>> B = {{7, 0, 0},{0, 0, 0},{0, 0, 1}};// 将矩阵A和B转换为CSR格式CSRMatrix csrA = toCSR(A);CSRMatrix csrB = toCSR(B);// 计算A和B的乘积vector<vector<int>> C = multiplySparseMatricesCSR(csrA, csrB, B[0].size());// 输出结果矩阵cout << "Result of A * B:" << endl;for (const auto& row : C) {for (int elem : row) {cout << elem << " ";}cout << endl;}return 0;
}
3. 优化前原始代码及其复杂度为O(m * n * p),这里是最朴素的思路,没有利用稀疏特性做任何优化
#include <iostream>
#include <vector>using namespace std;// 定义稀疏矩阵乘法函数
vector<vector<int>> multiplySparseMatrices(vector<vector<int>>& A, vector<vector<int>>& B) {int rowA = A.size();int colA = A[0].size();int rowB = B.size();int colB = B[0].size();// 初始化结果矩阵,大小为rowA * colBvector<vector<int>> C(rowA, vector<int>(colB, 0));// 遍历矩阵A的每一行for (int i = 0; i < rowA; i++) {// 遍历矩阵A的每个列,寻找非零元素for (int k = 0; k < colA; k++) {if (A[i][k] != 0) {// 当A的某个位置非零时,计算该元素和矩阵B的第k行for (int j = 0; j < colB; j++) {if (B[k][j] != 0) {C[i][j] += A[i][k] * B[k][j];}}}}}return C;
}int main() {// 定义稀疏矩阵Avector<vector<int>> A = {{1, 0, 0},{-1, 0, 3}};// 定义稀疏矩阵Bvector<vector<int>> B = {{7, 0, 0},{0, 0, 0},{0, 0, 1}};// 计算A和B的乘积vector<vector<int>> C = multiplySparseMatrices(A, B);// 输出结果矩阵cout << "Result of A * B:" << endl;for (const auto& row : C) {for (int elem : row) {cout << elem << " ";}cout << endl;}return 0;
}