I recently started a new project called Blaze, a set of hand-written CUDA kernels targeting the NVIDIA B200 GPU (SM100 / Blackwell) for LLM inference. No frameworks, no wrappers — just CUDA C++ with inline PTX for Blackwell’s new tcgen05 instructions. The goal is end-to-end Llama-7B text generation using only custom kernels that exploit Blackwell’s new hardware primitives: tcgen05 tensor cores, Tensor Memory (TMEM), and the Tensor Memory Accelerator (TMA).

This post covers Phase 0: hardware bringup on a B200 GPU.


Why Build From Scratch?

Blackwell introduced a fundamentally new tensor core programming model. Previous generations (Ampere, Hopper) used wmma or wgmma instructions where warps or warp groups collectively owned matrix operations. Blackwell’s tcgen05 instructions are different — a single thread dispatches an MMA that the hardware executes asynchronously, with accumulators living in a dedicated on-chip memory (TMEM) rather than in registers. This decouples compute from data management in a way that wasn’t possible before.

Existing inference engines are adapting Hopper-era kernels to Blackwell. Blaze starts from scratch on SM100 to take full advantage of these new primitives.

What’s New in Blackwell for Kernels

A quick summary of the hardware features Blaze targets:

  • Tensor Memory (TMEM): 256 KB per SM of dedicated accumulator storage. MMA results live here instead of in registers, freeing the register file for softmax, scaling, and other CUDA-core work that can run in parallel with tensor core operations.
  • tcgen05.mma: The new tensor core instruction. A single thread issues it, the hardware executes an entire tile’s multiply-accumulate asynchronously, and results land in TMEM. This enables deeper pipelines with lower register pressure than Hopper’s warp-group model.
  • TMA (Tensor Memory Accelerator): Hardware engine for async bulk loads from global memory directly into shared memory, bypassing the L1 cache entirely. Supports 2D tensor addressing with swizzling.
  • Native FP4: E2M1 precision with hardware dequantization and block scaling, executed directly on tensor cores. 2x memory savings over FP8.

Phase 0: How Can We Actually Run tcgen05 on B200?

Phase 0 answers two basic questions before writing any production kernels:

  1. How can we execute tcgen05 PTX instructions on the B200?
  2. Is the matrix indexing correct?

Experiment 1: TMEM Lifecycle

The first kernel (hello_tcgen05.cu) does nothing but allocate and deallocate TMEM:

tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [smem_addr], num_columns;
tcgen05.dealloc.cta_group::1.sync.aligned.b32 tmem_addr, num_columns;
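
Inside the kernel these instructions are issued through inline PTX. A minimal sketch of the wrappers, assuming the usual convention of passing a shared-memory address for alloc to write the TMEM base address into (the function names are mine, not the repo's):

```cuda
#include <cstdint>

// Hypothetical inline-PTX wrappers. dst_smem is a shared-memory address
// (obtained via __cvta_generic_to_shared) where the hardware writes the
// allocated TMEM base address. Both must be executed by a full warp.
__device__ void tmem_alloc(uint32_t dst_smem, uint32_t ncols) {
    asm volatile(
        "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
        :: "r"(dst_smem), "r"(ncols));
}

__device__ void tmem_dealloc(uint32_t tmem_addr, uint32_t ncols) {
    asm volatile(
        "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
        :: "r"(tmem_addr), "r"(ncols));
}
```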

This sounds trivial, but getting it to run on real hardware surfaced several constraints that aren’t obvious from the PTX documentation:

  • Minimum 32 TMEM columns. tcgen05.alloc requires num_columns >= 32 (power of 2, max 512). Anything smaller causes the hardware to trap at a SASS-level UTCATOMSWS instruction with “illegal instruction”.
  • Cluster launch is mandatory. Even a single-CTA kernel using cta_group::1 requires the cluster scheduling infrastructure. The standard <<<grid, block>>> launch syntax doesn’t initialize the cluster context — you need cudaLaunchKernelEx with cudaLaunchAttributeClusterDimension and a __cluster_dims__ kernel attribute. Without this, the kernel compiles fine but traps at runtime.
  • Alloc and dealloc are warp-collective. Both instructions are .sync.aligned, meaning all 32 threads in the warp must reach them together. Guarding with if (threadIdx.x == 0) instead of if (warp_id == 0) causes a silent deadlock — the SM hangs indefinitely with no error or timeout.
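
Putting the launch constraints together, the host side looks roughly like this (kernel body and dimensions are placeholders; the point is the cudaLaunchKernelEx path):

```cuda
#include <cuda_runtime.h>

// A 1x1x1 cluster is still a cluster: even this trivial kernel must go
// through cudaLaunchKernelEx so the cluster context gets initialized.
__global__ void __cluster_dims__(1, 1, 1) hello_tcgen05() { /* ... */ }

void launch() {
    cudaLaunchConfig_t cfg = {};
    cfg.gridDim  = dim3(1);
    cfg.blockDim = dim3(32);  // at least one full warp for the collective alloc

    cudaLaunchAttribute attr;
    attr.id = cudaLaunchAttributeClusterDimension;
    attr.val.clusterDim.x = 1;
    attr.val.clusterDim.y = 1;
    attr.val.clusterDim.z = 1;
    cfg.attrs    = &attr;
    cfg.numAttrs = 1;

    cudaLaunchKernelEx(&cfg, hello_tcgen05);
}
```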

The kernel compiles for sm_100a (the a suffix is required to enable architecture-accelerated features) and runs successfully.
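
The build line is a single nvcc invocation (flags beyond -arch are illustrative):

```shell
# sm_100a enables the architecture-specific tcgen05 / TMEM instructions;
# plain sm_100 rejects them at ptxas time.
nvcc -arch=sm_100a -O3 hello_tcgen05.cu -o hello_tcgen05
```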

Experiment 2: Naive FP16 GEMM

The second kernel (gemm_fp16_naive.cu) computes a 4096x4096 FP16 matrix multiply using CUDA cores (not tensor cores), while exercising TMEM alloc/dealloc across 1024 concurrent CTAs. It validates correctness against cuBLAS.

The kernel uses a standard tiled approach: 128x128 output tiles, 16-wide K-tiles loaded cooperatively into shared memory, FP32 accumulators in registers, final output narrowed to FP16. Each CTA also allocates and deallocates 32 TMEM columns to stress-test the lifecycle under real workload conditions.
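
The shape of that compute loop, with the TMEM lifecycle and edge handling stripped out: the 128x128x16 tiling matches the description above, but the per-thread register tile (8x16, which yields the 128 FP32 accumulators discussed in the profiling section) and all naming are my reconstruction, not the repo's code. Assumes N is a multiple of the tile sizes.

```cuda
#include <cuda_fp16.h>

#define BM 128  // output tile rows per CTA
#define BN 128  // output tile cols per CTA
#define BK 16   // K-tile depth
#define TM 8    // register tile rows per thread
#define TN 16   // register tile cols per thread

// One 128-thread CTA computes a 128x128 tile of C = A * B (row-major NxN),
// accumulating in FP32 and narrowing the output to FP16.
// Launch sketch: grid = dim3(N / BN, N / BM), block = dim3(128).
__global__ void gemm_fp16_naive(const __half* A, const __half* B,
                                __half* C, int N) {
    __shared__ __half As[BM][BK];
    __shared__ __half Bs[BK][BN];

    const int tid  = threadIdx.x;          // 0..127
    const int trow = (tid / 8) * TM;       // this thread's rows in the tile
    const int tcol = (tid % 8) * TN;       // this thread's cols in the tile
    const int brow = blockIdx.y * BM;
    const int bcol = blockIdx.x * BN;

    float acc[TM][TN] = {};                // 128 FP32 accumulators per thread

    for (int k0 = 0; k0 < N; k0 += BK) {
        // Cooperative load: 2048 halves per tile / 128 threads = 16 each.
        for (int i = tid; i < BM * BK; i += blockDim.x)
            As[i / BK][i % BK] = A[(brow + i / BK) * N + k0 + i % BK];
        for (int i = tid; i < BK * BN; i += blockDim.x)
            Bs[i / BN][i % BN] = B[(k0 + i / BN) * N + bcol + i % BN];
        __syncthreads();

        for (int kk = 0; kk < BK; ++kk)
            for (int i = 0; i < TM; ++i)
                for (int j = 0; j < TN; ++j)
                    acc[i][j] += __half2float(As[trow + i][kk]) *
                                 __half2float(Bs[kk][tcol + j]);
        __syncthreads();
    }

    for (int i = 0; i < TM; ++i)
        for (int j = 0; j < TN; ++j)
            C[(brow + trow + i) * N + bcol + tcol + j] =
                __float2half(acc[i][j]);
}
```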

Performance:
  Average time: 44.054 ms
  Throughput:   3.12 TFLOPS (FP16, CUDA cores)

Correctness vs cuBLAS:
  Max absolute error:  0.750000
  Mean absolute error: 0.035081
  Mean relative error: 0.003848
  Mismatched elements: 0 / 16777216
  Status: PASS

The 0.4% mean relative error confirms correctness. The nonzero max error is expected: FP16 arithmetic is not associative, so two correct implementations with different accumulation orders produce slightly different results, and the discrepancy is largest for near-zero elements.

Profiling: CUDA Cores vs Tensor Cores

The most interesting part of Phase 0 is profiling the naive kernel against cuBLAS with Nsight Compute using SM100-native metrics.

cuBLAS on B200 dispatches a CUTLASS SM100 kernel (cutlass3x_sm100_tensorop_h256x256x16gemm_..._2sm) that uses tcgen05.mma with 2-SM cooperative tiles and TMA for all data movement. Here’s what the profiles look like side by side:

| Metric                | Naive GEMM   | cuBLAS       | Ratio |
|-----------------------|--------------|--------------|-------|
| Duration              | 16.91 ms     | 178.40 us    | 95x   |
| Instructions executed | 6.0 billion  | 3.9 million  | 1520x |
| IPC                   | 2.34         | 0.22         | 0.09x |
| Highest pipeline      | FMA (42.1%)  | TMEM (75.8%) | -     |
| SMEM/block            | 8.32 KB      | 231.42 KB    | 28x   |
| L1 global loads       | 50.3M        | 0            | -     |

A few observations:

  • 1520x fewer instructions. Each tcgen05.mma retires an entire tile of multiply-accumulates in one instruction. The naive kernel issues individual FMA instructions per element pair. This ratio alone captures why tensor cores exist.

  • TMEM is the bottleneck for cuBLAS, and that’s expected. Nsight Compute’s full profile identifies TMEM as the highest-utilized pipeline at 75.8%, encompassing UTCMMA, LDT/STT, UTCCP, and UTCSHIFT operations. Being bottlenecked on the tensor memory pipeline means compute is the limiter, not data movement — the optimal operating point for a GEMM kernel.

  • The IPC paradox. Our kernel has 10x higher IPC (2.34 vs 0.22 instructions per cycle), yet runs 95x slower. cuBLAS’s low IPC reflects that each issued instruction does orders of magnitude more work. High IPC on scalar instructions doesn’t compensate for hardware-accelerated matrix operations.

  • TMA at 1%. cuBLAS loads all data through TMA (zero L1 global loads), but the TMA pipeline barely registers at 1% utilization because loads complete asynchronously and overlap entirely with MMA compute. The naive kernel routes 50 million accesses through L1 via explicit loads.

  • cuBLAS uses 231 KB of shared memory per block. Each 2-SM cluster consumes nearly the full 256 KB available, limiting occupancy to 1 block per SM. Despite only 12.5% theoretical occupancy, the kernel achieves 75.8% pipeline utilization through deep instruction-level parallelism within the few active warps. Low occupancy is intentional and expected for GEMM.

  • Register spills dominate the naive kernel. The full profile reveals that 94% of our kernel’s L1 traffic is local memory — the 128 FP32 accumulators per thread spill from registers to thread-private global memory. Each local load/store uses only 1 of 32 bytes per sector, a 32x bandwidth waste from strided access patterns. This is the dominant bottleneck, not the CUDA-core FMA throughput.

What’s Next

Phase 0 established that the tcgen05 pipeline works end-to-end on real hardware and gave us a cuBLAS profile to benchmark against. cuBLAS's 75.8% TMEM pipeline utilization is the target to match.

Phase 1 moves from CUDA cores to tcgen05.mma for the actual GEMM kernels that will drive inference.

The full Phase 0 report with detailed profiling data is at docs/phase0.md.