GEMM with Collective Operations
This program implements the SUMMA matrix multiplication algorithm and serves as an example of using the `collectives_2d` library together with `SdkRuntime` and the memcpy framework.
The host code first copies tiles of `A` and `B` onto their corresponding PEs. It then uses the remote procedure call (RPC) mechanism to launch the function `main`, at which point the GEMM computation begins.
We perform GEMM in `P` steps on a grid of `P x P` processors.
At each step `i`, PEs in the `i`th column broadcast their home tiles of `A` to the other PEs in their row, and PEs in the `i`th row broadcast their home tiles of `B` to the other PEs in their column. Once both broadcasts are complete, as determined by `x_done()` and `y_done()` both being activated, each PE computes `C_tile += Ap * Bp`, where `Ap` and `Bp` are pointers to either the PE's home tile or the tile it received through broadcasts.
When the computation is complete, the host copies the resulting tiles of `C` back from the device.
layout.csl
// Side length of the square P x P program rectangle.
param P : u16;

// Per-PE tile dimensions: A is Mt x Kt, B is Kt x Nt, C is Mt x Nt.
param Mt : u16;
param Kt : u16;
param Nt : u16;

// Colors handed to every PE for the RPC launch and the RPC exit path.
const LAUNCH : color = @get_color(8);
const EXIT : color = @get_color(9);

// Parameter generators for the memcpy framework and the 2D collectives.
const memcpy_cfg = @import_module( "<memcpy_multi/get_params>", .{
  .width = P,
  .height = P
});
const c2d_cfg = @import_module("<collectives_2d/params>");

layout {
  @set_rectangle(P, P);

  // Configure the grid one column at a time.
  var pe_x: u16 = 0;
  while (pe_x < P) : (pe_x += 1) {
    // The memcpy parameters are obtained per column only.
    const col_memcpy_params = memcpy_cfg.get_params(pe_x);

    var pe_y: u16 = 0;
    while (pe_y < P) : (pe_y += 1) {
      // Colors 0-3 go to the x (row) collectives, 4-7 to the y (column) ones.
      const pe_c2d_params = c2d_cfg.get_params(pe_x, pe_y, .{
        .x_colors = .{ @get_color(0), @get_color(1), @get_color(2), @get_color(3) },
        .y_colors = .{ @get_color(4), @get_color(5), @get_color(6), @get_color(7) },
      });

      @set_tile_code(pe_x, pe_y, "pe.csl", .{
        .c2d_params = pe_c2d_params,
        .memcpy_params = col_memcpy_params,
        .Mt = Mt,
        .Kt = Kt,
        .Nt = Nt,
        .LAUNCH = LAUNCH,
        .EXIT = EXIT
      });
    }
  }

  // Host-visible symbols: the three tile pointers and the RPC entry point.
  @export_name("A", [*]f32, true);
  @export_name("B", [*]f32, true);
  @export_name("C", [*]f32, true);
  @export_name("main", fn()void);
}
pe.csl
// This program implements the SUMMA matrix multiplication algorithm and is
// written as an example to show how to use the `collectives_2d` library.
// We perform GEMM in `P` many steps on a grid of `P x P` processors.
// At each step `i`, PEs in the `i`th column broadcast their home tiles of `A`
// to other PEs in their row, and PEs in the `i`th row broadcast their home
// tiles of `B` to other PEs in their column. Once both broadcasts are complete
// as determined by `x_done()` and `y_done()` both being activated,
// each PE computes `C_tile += Ap * Bp` where `Ap` and `Bp` are pointers to
// either the PE's home tile or the tile it received through broadcasts.
// Per-PE collectives_2d parameters, produced in layout.csl by
// `c2d.get_params(Px, Py, ...)` and passed in through @set_tile_code.
param c2d_params: comptime_struct;
// Collectives instance for the x dimension (along this PE's row); used in
// `main` to broadcast tiles of A. Uses queues 2 and 4 and DSR id 1.
const mpi_x = @import_module("<collectives_2d/pe>", .{
.dim_params = c2d_params.x,
.queues = [2]u16{2,4},
.dest_dsr_ids = [1]u16{1},
.src0_dsr_ids = [1]u16{1},
.src1_dsr_ids = [1]u16{1}
});
// Collectives instance for the y dimension (along this PE's column); used in
// `main` to broadcast tiles of B. Uses queues 3 and 5 and DSR id 2, so its
// queues and DSR ids are disjoint from those of `mpi_x`.
const mpi_y = @import_module("<collectives_2d/pe>", .{
.dim_params = c2d_params.y,
.queues = [2]u16{3,5},
.dest_dsr_ids = [1]u16{2},
.src0_dsr_ids = [1]u16{2},
.src1_dsr_ids = [1]u16{2}
});
// Per-PE memcpy parameters, passed in from layout.csl via @set_tile_code.
param memcpy_params: comptime_struct;
// Task colors
// `compute_color` triggers the local GEMM task. `x_color` and `y_color` are
// passed to the row/column broadcasts in `main` and are bound to the
// `x_done`/`y_done` tasks.
const compute_color = @get_color(12);
const x_color = @get_color(14);
const y_color = @get_color(15);
// Matrix size params
// Per-PE tile dimensions: A tile is Mt x Kt, B tile is Kt x Nt, C tile is Mt x Nt.
// NOTE(review): layout.csl declares these as u16 and passes them to these
// i16 params — confirm the conversion is intended.
param Mt: i16;
param Kt: i16;
param Nt: i16;
param LAUNCH : color; // a routable color for RPC
param EXIT: color; // entrypoint to leave RPC
// memcpy uses input/output queue 0
const sys_mod = @import_module("<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
.LAUNCH = LAUNCH
}));
// Width of the program rectangle; layout.csl calls @set_rectangle(P, P), so
// this is also the number of SUMMA steps.
const P = @get_rectangle().width;
// Home tiles owned by this PE, all stored column-major. The host (run.py)
// fills A_tile and B_tile before launching `main`; C_tile accumulates the
// partial products and is copied back by the host at the end.
var A_tile: [Mt*Kt]f32 = @zeros([Mt*Kt]f32);
var B_tile: [Kt*Nt]f32 = @zeros([Kt*Nt]f32);
var C_tile: [Mt*Nt]f32 = @zeros([Mt*Nt]f32);

// Pointers to the home tiles; these are what the comptime block exports to
// the host as "A", "B" and "C".
var ptr_A: [*]f32 = &A_tile;
var ptr_B: [*]f32 = &B_tile;
var ptr_C: [*]f32 = &C_tile;

// Landing buffers for the A and B tiles that arrive via broadcast when this
// PE is not the sender for the current step.
var A_buffer: [Mt*Kt]f32 = @zeros([Mt*Kt]f32);
var B_buffer: [Kt*Nt]f32 = @zeros([Kt*Nt]f32);

// This PE's ids within its row (px) and column (py) collectives; they are
// assigned on the first call to `main`.
var px: u16;
var py: u16;
// Bound to `x_color`, which `main` passes to the row (x) broadcast.
// Activates `compute`; because `compute_color` is blocked at the start of
// every step, `compute` does not run until `y_done` has also unblocked it.
task x_done() void {
@activate(compute_color);
}
// Bound to `y_color`, which `main` passes to the column (y) broadcast.
// Unblocks `compute`; together with the activation from `x_done`, this lets
// the local GEMM for the current step run.
task y_done() void {
@unblock(compute_color);
}
// Index of the current SUMMA step (0 .. P-1); advanced by `compute`.
var step: u16 = 0;
// Starts one SUMMA step by launching the row and column broadcasts.
// The host invokes it over RPC for the first step, and `compute` calls it
// again for each later step.
fn main() void {
@assert(step < P);
// The first time through we need to initialize our state
if (step == 0) {
mpi_x.init();
mpi_y.init();
px = mpi_x.pe_id;
py = mpi_y.pe_id;
}
// Communicate along both rows and columns
// `step` is passed as the broadcast root. The PE whose row/column id equals
// `step` broadcasts its home tile; every other PE receives into its buffer.
// The tiles are passed as [*]u32 via @ptrcast (f32 and u32 are both 32-bit),
// and x_color / y_color trigger x_done / y_done.
const Ap = if (px == step) &A_tile else &A_buffer;
const Bp = if (py == step) &B_tile else &B_buffer;
mpi_x.broadcast(step, @ptrcast([*]u32, Ap), Mt * Kt, x_color);
mpi_y.broadcast(step, @ptrcast([*]u32, Bp), Kt * Nt, y_color);
}
// Local GEMM for the current step: C_tile += Ap * Bp. Ap and Bp point either
// at this PE's home tile (if it was the broadcast root) or at the buffer it
// received into. Runs only once `x_done` has activated it and `y_done` has
// unblocked it.
task compute() void {
const Ap = if (px == step) &A_tile else &A_buffer;
const Bp = if (py == step) &B_tile else &B_buffer;
// Do an fmacs based local GEMM
// All tiles are column-major. For each k, column k of A (Mt contiguous
// floats) is scaled by B[k, j] and accumulated into column j of C.
var A_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{Mt} -> A_tile[i] });
// Re-point the DSD at whichever A tile is in use this step.
A_dsd = @set_dsd_base_addr(A_dsd, Ap);
for (@range(i16, Kt)) |k| {
// Restart at column 0 of C for every k.
var C_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{Mt} -> C_tile[i] });
for (@range(i16, Nt)) |j| {
// B[k, j] in column-major Kt x Nt storage.
const b = Bp.*[j*Kt + k];
// C[:, j] += A[:, k] * b
@fmacs(C_dsd, C_dsd, A_dsd, b);
// Move to the next column of C.
C_dsd = @increment_dsd_offset(C_dsd, Mt, f32);
}
// Move to the next column of A.
A_dsd = @increment_dsd_offset(A_dsd, Mt, f32);
}
// Advance to the next step and re-block `compute`, so that both broadcasts
// of the next step must complete again before it runs.
step += 1;
@block(compute_color);
if (step != P) {
main();
} else {
// All P steps are done: trigger f_exit to leave the RPC.
@activate(EXIT);
}
}
// Bound to `EXIT`, which `compute` activates after the final step.
// Unblocks the memcpy command stream so the host-side RPC can complete.
task f_exit() void {
// the user must unblock cmd color for every PE
sys_mod.unblock_cmd_stream();
}
comptime {
// Attach each task to the color that triggers it.
@bind_task(f_exit, EXIT);
@bind_task(compute, compute_color);
@bind_task(x_done, x_color);
@bind_task(y_done, y_color);
// `compute` starts blocked: it runs only after y_done unblocks it and
// x_done activates it.
@block(compute_color);
// Expose the home tiles to the host under the names run.py looks up.
@export_symbol(ptr_A, "A");
@export_symbol(ptr_B, "B");
@export_symbol(ptr_C, "C");
// Make `main` callable from the host.
@export_symbol(main);
// Use LAUNCH as the color on which RPC calls arrive.
@rpc(LAUNCH);
}
run.py
#!/usr/bin/env cs_python
import argparse
import json
import numpy as np
from cerebras.sdk.runtime import runtime_utils
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module
# Command-line interface. --name is the compile output directory (cslc's -o);
# --cmaddr is the IP:port of a CS system (presumably the simulator is used
# when it is omitted).
parser = argparse.ArgumentParser()
parser.add_argument("--name", required=True, help="the test name")
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()

# Get params from compile metadata
with open(f"{args.name}/out.json", encoding='utf-8') as json_file:
    compile_data = json.load(json_file)

# Kernel rectangle and per-PE matrix dimensions
params = compile_data['params']
P = int(params['P'])
Mt = int(params['Mt'])
Kt = int(params['Kt'])
Nt = int(params['Nt'])

# Full matrix dimensions
# A is M x K, B is K x N, C is M x N
M = Mt * P
K = Kt * P
N = Nt * P

memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
memcpy_order = MemcpyOrder.ROW_MAJOR

# Options shared by every (blocking, non-streaming) memcpy transfer below.
memcpy_opts = {
    "streaming": False,
    "data_type": memcpy_dtype,
    "nonblock": False,
    "order": memcpy_order,
}


def _tile_portmap(name, rows, cols, tile_rows, tile_cols):
    """Build the tensor-to-PE port map for a rows x cols matrix `name`.

    Element (j, i) of the matrix (row j, column i) is placed on
    PE[i // tile_cols, j // tile_rows] at local index
    [i % tile_cols, j % tile_rows]. With the ROW_MAJOR memcpy order, this makes
    each PE's tile column-major, as pe.csl expects.

    Returns the port-map string consumed by runtime_utils.
    """
    return (f"{{ {name}[j=0:{rows-1}][i=0:{cols-1}] -> "
            f"[PE[i//{tile_cols}, j//{tile_rows}] -> "
            f"index[i%{tile_cols}, j%{tile_rows}]] }}")


A = np.arange(M * K, dtype=np.float32).reshape((M, K))
B = np.arange(K * N, dtype=np.float32).reshape((K, N))

simulator = SdkRuntime(args.name, cmaddr=args.cmaddr)
symbol_A = simulator.get_id("A")
symbol_B = simulator.get_id("B")
symbol_C = simulator.get_id("C")

simulator.load()
simulator.run()
try:
    # Scatter the tiles of A and B onto the P x P grid.
    iportmap_A = _tile_portmap("A", M, K, Mt, Kt)
    (px, py, w, h, l, data) = runtime_utils.convert_input_tensor(iportmap_A, A)
    simulator.memcpy_h2d(symbol_A, data, px, py, w, h, l, **memcpy_opts)

    iportmap_B = _tile_portmap("B", K, N, Kt, Nt)
    (px, py, w, h, l, data) = runtime_utils.convert_input_tensor(iportmap_B, B)
    simulator.memcpy_h2d(symbol_B, data, px, py, w, h, l, **memcpy_opts)

    # Launch the kernel over RPC (blocking call: nonblock=False).
    simulator.call("main", [], nonblock=False)

    # Gather the tiles of C back into a full M x N matrix.
    oportmap_C = _tile_portmap("C", M, N, Mt, Nt)
    (px, py, w, h, l, data) = runtime_utils.prepare_output_tensor(oportmap_C, np.float32)
    simulator.memcpy_d2h(data, symbol_C, px, py, w, h, l, **memcpy_opts)
    C = runtime_utils.format_output_tensor(oportmap_C, np.float32, data)
finally:
    # Release the runtime even if a transfer or the kernel launch fails.
    simulator.stop()

# Check the result
C_expected = np.dot(A, B)
# NOTE(review): exact float32 equality assumes the device accumulates each
# dot product in the same order as the host BLAS; switch to
# np.testing.assert_allclose if that assumption ever breaks.
np.testing.assert_equal(C_expected, C)

print("SUCCESS")
commands.sh
#!/usr/bin/env bash
# Compile the P=4 SUMMA GEMM program (14x14x14 tiles per PE) with memcpy
# support, then run the host script against the compiled output.
set -o errexit

out_dir=out

cslc ./layout.csl \
  --fabric-dims=11,6 \
  --fabric-offsets=4,1 \
  --params=P:4,Mt:14,Kt:14,Nt:14 \
  --memcpy \
  --channels=1 \
  -o "${out_dir}"

cs_python run.py --name "${out_dir}"