Collective Communications

Collective Communications

The <collectives_2d> library can be used for communication between PEs in the same row or column. It mimics the capabilities provided by message passing interface (MPI) collective operations found in other programming languages.

This example showcases each of the currently available communication primitives while using the library across two independent dimensions. The communication tasks are executed asynchronously.

task_x uses the broadcast primitive to transmit data from the first PE in every row to every other PE in the same row. After the data is received, reduce_fadds computes the vector sum of the data received in broadcast_recv. The result is transmitted back to the first PE in every row.

task_y operates concurrently along every column of PEs. The task first uses scatter to distribute chunk_size slices of scatter_data across every other PE in the same column. The task then uses gather to collect the chunk_size slices of data distributed by scatter. Because gather is the inverse of scatter, the collective communications transmit the data from scatter_data to gather_recv unchanged.

layout.csl

// resources to route the data between the host and the device.
//

// color map
//
// color  var    color  var              color  var              color  var
//   0 x.c0         9  LAUNCH              18                      27   reserved (memcpy)
//   1 x.c1        10                      19                      28   reserved (memcpy)
//   2 x.c2        11                      20                      29   reserved
//   3 x.c3        12                      21    reserved (memcpy) 30   reserved (memcpy)
//   4 y.c0        13                      22    reserved (memcpy) 31   reserved
//   5 y.c1        14                      23    reserved (memcpy) 32
//   6 y.c2        15  x_color             24                      33
//   7 y.c3        16  y_color             25                      34
//   8             17                      26                      35
//

// Dimensions of the PE rectangle: Pw columns (width) by Ph rows (height).
param Pw: u16;
param Ph: u16;
// Number of 32-bit words each PE contributes to a collective slice.
param chunk_size: u16;

// Color ID for the host-launch (RPC) color; supplied on the cslc command line.
param LAUNCH_ID: i16;

const LAUNCH: color = @get_color(LAUNCH_ID);

// Local activation colors for the row (x) and column (y) tasks
// (see color map above: 15 and 16 are otherwise unused).
const x_color = @get_color(15);
const y_color = @get_color(16);

// Parameter generator for the 2D collectives library.
const c2d = @import_module("<collectives_2d/params>");

// memcpy infrastructure for host<->device transfers across the rectangle.
const memcpy = @import_module( "<memcpy_multi/get_params>", .{
    .width = Pw,
    .height = Ph
    });

layout {
  @set_rectangle(Pw, Ph);

  // Assign the tile program and per-PE parameters to every PE in the rectangle.
  var Px: u16 = 0;
  while (Px < Pw) : (Px += 1) {
    var Py: u16 = 0;
    while (Py < Ph) : (Py += 1) {
      // Routing colors used by the collectives library: colors 0-3 carry
      // row (x-dimension) traffic, colors 4-7 carry column (y-dimension)
      // traffic (matching the color map above).
      const params = c2d.get_params(Px, Py, .{
        .x_colors = .{ @get_color(0), @get_color(1), @get_color(2), @get_color(3) },
        .y_colors = .{ @get_color(4), @get_color(5), @get_color(6), @get_color(7) },
      });
      const memcpy_params = memcpy.get_params(Px);
      @set_tile_code(Px, Py, "code.csl", .{
        .x_color = x_color,
        .y_color = y_color,
        .memcpy_params = memcpy_params,
        .LAUNCH = LAUNCH,
        .c2d_params = params,
        .chunk_size = chunk_size });
    }
  }

  // Export symbol names so the host (run.py) can address device buffers
  // (writable pointers) and launch the two driver functions by name.
  @export_name("broadcast_data", [*]u32, true);
  @export_name("scatter_data", [*]u32, true);
  @export_name("broadcast_recv", [*]u32, true);
  @export_name("faddh_result", [*]u32, true);
  @export_name("gather_recv", [*]u32, true);

  @export_name("f_run_x", fn()void);
  @export_name("f_run_y", fn()void);
}

code.csl

// Per-PE collectives parameters generated by layout.csl.
param c2d_params: comptime_struct;
// Number of 32-bit words each PE contributes to a collective slice.
param chunk_size: u16;

// Activation colors for the row (x) and column (y) state-machine tasks.
param x_color: color;
param y_color: color;

param memcpy_params: comptime_struct;
param LAUNCH: color;

// memcpy runtime module; provides unblock_cmd_stream() used when a task
// state machine completes.
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .LAUNCH = LAUNCH
      }));

const rect_height = @get_rectangle().height;
const rect_width = @get_rectangle().width;

// Row-dimension collectives instance. The x and y instances are given
// disjoint queue IDs and DSR IDs so the two tasks can run concurrently
// without contending for the same hardware resources.
const mpi_x = @import_module("<collectives_2d/pe>", .{
    .dim_params = c2d_params.x,
    .queues = [2]u16{2,4},
    .dest_dsr_ids = [1]u16{1},
    .src0_dsr_ids = [1]u16{1},
    .src1_dsr_ids = [1]u16{1}
    });
// Column-dimension collectives instance.
const mpi_y = @import_module("<collectives_2d/pe>", .{
    .dim_params = c2d_params.y,
    .queues = [2]u16{3,5},
    .dest_dsr_ids = [1]u16{2},
    .src0_dsr_ids = [1]u16{2},
    .src1_dsr_ids = [1]u16{2}
    });


// Total words exchanged along a full row (Nx) and a full column (Ny).
const Nx = chunk_size * rect_width;
const Ny = chunk_size * rect_height;

// broadcast_data and scatter_data supplied by run.py
var broadcast_data = @zeros([Nx]u32);
var broadcast_recv = @zeros([Nx]u32);
var faddh_result = @zeros([Nx]u32);

var scatter_data = @zeros([Ny]u32);
var scatter_recv = @zeros([Ny]u32);
var gather_recv = @zeros([Ny]u32);

// Pointers exported to the host for memcpy access (see comptime exports below).
var ptr_broadcast_data: [*]u32 = &broadcast_data;
var ptr_scatter_data: [*]u32 = &scatter_data;
var ptr_broadcast_recv: [*]u32 = &broadcast_recv;
var ptr_faddh_result: [*]u32 = &faddh_result;
var ptr_gather_recv: [*]u32 = &gather_recv;

// Row (x-dimension) state machine. The task is re-activated on x_color when
// each asynchronous collective completes, advancing task_x_state:
//   state 0: broadcast Nx words along the row (root = row PE 0)
//   state 1: reduce the received words (as f32) back to row PE 0
//   else:    done; release the host command stream
var task_x_state: u16 = 0;
task task_x() void {
   switch (task_x_state) {
      0 => {
         mpi_x.init();
         // Root (pe_id 0) sends broadcast_data; every other PE in the row
         // receives the Nx words into broadcast_recv.
         var send_buf = @ptrcast([*]u32, &broadcast_data);
         var recv_buf = @ptrcast([*]u32, &broadcast_recv);
         if (mpi_x.pe_id == 0) {
            mpi_x.broadcast(0, send_buf, Nx, x_color);
         } else {
            mpi_x.broadcast(0, recv_buf, Nx, x_color);
         }

         task_x_state += 1;
      },
      1 => {
         // The u32 buffers are reinterpreted as f32 for the float add-reduce;
         // the reduced vector lands in faddh_result on row PE 0.
         var send_buf = @ptrcast([*]f32, &broadcast_recv);
         var recv_buf = @ptrcast([*]f32, &faddh_result);

         mpi_x.reduce_fadds(0, send_buf, recv_buf, Nx, x_color);

         task_x_state += 1;
      },
      else => {
         // WARNING: the user must unblock cmd color for every PE
         sys_mod.unblock_cmd_stream();
         return;
      }
   }
}

// Column (y-dimension) state machine, analogous to task_x but re-activated
// on y_color:
//   state 0: scatter chunk_size-word slices of scatter_data down the column
//   state 1: gather the scattered slices into gather_recv (root = column PE 0)
//   else:    done; release the host command stream
var task_y_state: u16 = 0;
task task_y() void {
   switch (task_y_state) {
      0 => {
         mpi_y.init();
         // Root (PE 0 of the column) distributes one chunk_size slice of
         // scatter_data to each PE's scatter_recv.
         var send_buf = @ptrcast([*]u32, &scatter_data);
         var recv_buf = @ptrcast([*]u32, &scatter_recv);

         mpi_y.scatter(0, send_buf, recv_buf, chunk_size, y_color);

         task_y_state += 1;
      },
      1 => {
         // Inverse operation: collect each PE's chunk_size slice back into
         // gather_recv on the root, reconstructing the scattered data.
         var send_buf = @ptrcast([*]u32, &scatter_recv);
         var recv_buf = @ptrcast([*]u32, &gather_recv);

         mpi_y.gather(0, send_buf, recv_buf, chunk_size, y_color);

         task_y_state += 1;
      },
      else => {
         // WARNING: the user must unblock cmd color for every PE
         sys_mod.unblock_cmd_stream();
         return;
      }
   }
}

// Bind the two state-machine tasks to their activation colors; the
// collectives library re-activates these colors on completion.
comptime {
   @bind_task(task_x, x_color);
   @bind_task(task_y, y_color);

}

// Host-callable entry point: kick off the row state machine (task_x).
fn f_run_x() void {
   @activate(x_color);

   // terminate when task_x finishes
}

// Host-callable entry point: kick off the column state machine (task_y).
fn f_run_y() void {
   @activate(y_color);

   // terminate when task_y finishes
}

// Export device buffer pointers and launch functions under the names
// declared in layout.csl, and register the RPC launch color.
comptime{
  @export_symbol(ptr_broadcast_data, "broadcast_data");
  @export_symbol(ptr_scatter_data, "scatter_data");
  @export_symbol(ptr_broadcast_recv, "broadcast_recv");
  @export_symbol(ptr_faddh_result, "faddh_result");
  @export_symbol(ptr_gather_recv, "gather_recv");
  @export_symbol(f_run_x);
  @export_symbol(f_run_y);
  @rpc(LAUNCH);
}

run.py

#!/usr/bin/env cs_python

# Host driver for the collectives_2d example: reads the compile metadata,
# sets up the SdkRuntime, pushes input data, launches the device tasks,
# and pulls back the results for verification.

import argparse
import json

import numpy as np

from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

# CLI: --name is the compile output directory; --cmaddr targets a real CS system.
cli = argparse.ArgumentParser()
cli.add_argument('--name', help='the test name')
cli.add_argument('--cmaddr', help='IP:port for CS system')
cli_args = cli.parse_args()
dirname = cli_args.name

# Parameters recorded by the compiler in out.json.
with open(f"{dirname}/out.json", encoding="utf-8") as fp:
  compile_params = json.load(fp)["params"]
Pw = int(compile_params["Pw"])
Ph = int(compile_params["Ph"])
chunk_size = int(compile_params["chunk_size"])
print(f"Pw = width of the core = {Pw}")
print(f"Ph = height of the core = {Ph}")
print(f"chunk_size = {chunk_size}")

# Per-PE buffer lengths: Nx words per row collective, Ny per column collective.
Nx = chunk_size * Pw
Ny = chunk_size * Ph

print(f"Nx = {Nx}, Ny = {Ny}")

memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
runner = SdkRuntime(dirname, cmaddr=cli_args.cmaddr)

# Resolve device symbols exported by code.csl.
sym_broadcast_data = runner.get_id("broadcast_data")
sym_scatter_data = runner.get_id("scatter_data")
sym_broadcast_recv = runner.get_id("broadcast_recv")
sym_faddh_result = runner.get_id("faddh_result")
sym_gather_recv = runner.get_id("gather_recv")

runner.load()
runner.run()

print("step 1: copy mode H2D(broadcast_data) to 1st column PEs")
# f32 ones; copied bit-for-bit into the device's u32 buffer (32-bit memcpy).
broadcast_data = np.ones((Ph, 1, Nx)).astype(np.float32)
runner.memcpy_h2d(sym_broadcast_data, broadcast_data.ravel(), 0, 0, 1, Ph, Nx, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=True)

print("step 2: copy mode H2D(scatter_data) to 1st row PEs")
scatter_data = np.ones((1, Pw, Ny)).astype(np.int32)
runner.memcpy_h2d(sym_scatter_data, scatter_data.ravel(), 0, 0, Pw, 1, Ny, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=True)

print("step 3: call f_run_x to test broadcast and reduction")
runner.launch("f_run_x", nonblock=False)

print("step 4: call f_run_y to test scatter and gather")
runner.launch("f_run_y", nonblock=False)

print("step 5: copy mode D2H(broadcast_recv)")
# broadcast on x: Px=0 broadcasts data to all other PEs
# broadcast_recv(y, x=0) = 0
# broadcast_recv(y, x !=0) = ones
broadcast_recv_1d = np.zeros(Ph*Pw*Nx, np.float32)
runner.memcpy_d2h(broadcast_recv_1d, sym_broadcast_recv, 0, 0, Pw, Ph, Nx, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
broadcast_recv = broadcast_recv_1d.reshape((Ph, Pw, Nx))

print("step 6: copy mode D2H(faddh_result) from 1st column PEs")
# reduce(broadcast_recv) to Px=0
faddh_result_1d = np.zeros(Ph*Nx, np.float32)
runner.memcpy_d2h(faddh_result_1d, sym_faddh_result, 0, 0, 1, Ph, Nx, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
faddh_result = faddh_result_1d.reshape((Ph, 1, Nx))

print("step 7: copy mode D2H(gather_recv) from 1st row PEs")
gather_recv_1d = np.zeros(Pw*Ny, np.int32)
runner.memcpy_d2h(gather_recv_1d, sym_gather_recv, 0, 0, Pw, 1, Ny, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
gather_recv = gather_recv_1d.reshape((1, Pw, Ny))

runner.stop()

# verify broadcast on x-direction: every PE except the root column (x=0)
# must have received a vector of ones
expected_bcast = np.ones(Nx, dtype=np.float32)
for y in range(Ph):
  for x in range(1, Pw):
    np.testing.assert_equal(broadcast_recv[y, x], expected_bcast)

# verify faddh_result at 1st column PEs
# reduce on x: reduce(broadcast_recvs) to Px=0
# where broadcast_recvs(y, x=0) = 0
#       broadcast_recvs(y, x != 0) = ones
# so each element of the sum is (Pw-1)
expected_faddh = np.full(Nx, (Pw-1), dtype=np.float32)
for y in range(Ph):
  np.testing.assert_equal(faddh_result[y, 0], expected_faddh)

# verify gather_recv at 1st row PEs: gather(scatter(ones)) must round-trip
expected_gather = np.ones(Ny, dtype=np.int32)
for x in range(Pw):
  np.testing.assert_equal(gather_recv[0, x], expected_gather)

print("SUCCESS")

commands.sh

#!/usr/bin/env bash

set -e

# Compile the device program: a 15x15 core at fabric offset (4,1) inside a
# 22x17 fabric (extra fabric area is used by the memcpy infrastructure).
# LAUNCH_ID:9 matches the LAUNCH color slot in the layout.csl color map.
cslc ./layout.csl --fabric-dims=22,17 --fabric-offsets=4,1 \
--params=Pw:15,Ph:15,chunk_size:3 -o out \
--params=LAUNCH_ID:9 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
# Run the host driver against the compiled output.
cs_python run.py --name out