Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Openmlsys 8.4 章节中 gemm_use_tile.cu 的代码问题 #2

Open
hiha3456 opened this issue May 29, 2023 · 1 comment
Open

Openmlsys 8.4 章节中 gemm_use_tile.cu 的代码问题 #2

hiha3456 opened this issue May 29, 2023 · 1 comment

Comments

@hiha3456
Copy link

这一部分有关三个 Layout 的代码我一直没有看明白,在这篇知乎中看到一样的内容后,我发现这里的代码实现可能有 bug。
按照定义:

  • LayoutTile 是每个 Block 有 (LayoutTile::m, LayoutTile::n) 个 float
  • LayoutBlock 是每个 Block 有 (LayoutBlock::m, LayoutBlock::n) 个 thread
  • LayoutThread 是每个 thread 中,每个 submatrix 有 (LayoutThread::m, LayoutThread::n) 个 float, 因为用的 float4,可以理解为 4*4。

那么此处 gemm_use_tile.cu 第10行和第11行 中对于m 和 n 的定义就有问题了,应该如下:

unsigned m= threadIdx.x* LayoutTile::m/LayoutBlock::m+ LayoutTile::m* blockIdx.x;
unsigned n= threadIdx.y* LayoutTile::n/LayoutBlock::n+ LayoutTile::n* blockIdx.y

同样的, gemm_use_tile.cu 第19行和第20行 中,iterationA 和 iterationB 应该分别指的是每个 thread 有多少个 (4,4) 的 subMatrix,这里应该是 2*2 = 4 个,那么 gemm_use_tile.cu 第21行和第22行 intervalA 和 intervalB 的定义就有问题了,按照后续代码,intervalA 和 intervalB 指的分别应该是每个 subMatrix 有多大,也就是 (LayoutThread::m, LayoutThread::n)

@weizhenhuan
Copy link

兄弟,我也一直觉得他这里的m和n有问题,我没理解错的话,如果是每个线程算4个4*4,那么LayoutTile是(8,8),LayoutThread是(4,4),但是这个LayoutBlock我一直不明白啥意思,每个block不是只有1个线程吗?
另外,m和n的计算最后 是不是应该还要乘一个blockDim.x,blockDim.y?亦即

unsigned m= threadIdx.x* LayoutTile::m/LayoutBlock::m+ LayoutTile::m* blockIdx.x * blockDim.x;
unsigned n= threadIdx.y* LayoutTile::n/LayoutBlock::n+ LayoutTile::n* blockIdx.y * blockDim.y;

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants