Sentinel Colors

In previous programs, we used so-called routable colors, which are colors associated with a route to direct the flow of wavelets. Routable colors are in the range 0 through 23. This example demonstrates the use of a non-routable color to signal the end of an input tensor, hence the name Sentinel Color.

In this example, the host sends the number of wavelets via RPC. The kernel counts the wavelets received via H2D and triggers the sentinel color when all wavelets have arrived. Since sentinel colors are not routable, the programmer should not specify a route for them, but they do need to bind the sentinel color to a task.

Here, the sentinel color activates the send_result task, which relays the result of the sum reduction back to the host.

code.csl

// resources to route the data between the host and the device.
//

// color map
//
// color  var        color  var          color  var              color  var
//   0 H2D              9                  18                      27   reserved (memcpy)
//   1 D2H             10                  19                      28   reserved (memcpy)
//   2 LAUNCH          11                  20                      29   reserved
//   3                 12                  21    reserved (memcpy) 30   reserved (memcpy)
//   4                 13                  22    reserved (memcpy) 31   reserved
//   5                 14                  23    reserved (memcpy) 32
//   6                 15                  24                      33
//   7                 16                  25                      34
//   8 end_computation 17                  26                      35
//

// Color IDs supplied on the cslc command line (see commands.sh).
param MEMCPYH2D_DATA_1_ID: i16; // streaming host-to-device data color
param MEMCPYD2H_DATA_1_ID: i16; // streaming device-to-host data color
param LAUNCH_ID: i16;           // RPC launch color
// number of PEs in a column
param size: i16;

const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const LAUNCH: color = @get_color(LAUNCH_ID);

// Sentinel color (non-routable, hence no route is configured for it):
// entrypoint to tell a PE that it is time to send the result to the host.
const end_computation: color = @get_color(8);

// memcpy helper parameterized for a 1 x size rectangle of PEs.
const memcpy = @import_module( "<memcpy_multi/get_params>", .{
    .width = 1,
    .height = size
    });

// Place the kernel on a 1 x size column of PEs and hand each PE's program
// the colors and memcpy parameters it needs.
layout {
  @set_rectangle(1, size);

  var idx :i16 = 0;
  while (idx < size) {
    // Every PE lives in column 0, so all PEs share column 0's memcpy params.
    const memcpy_params = memcpy.get_params(0);

    @set_tile_code(0, idx, "pe_program.csl", .{
      .end_computation = end_computation,
      .memcpy_params = memcpy_params,
      .MEMCPYH2D_DATA_1 = MEMCPYH2D_DATA_1,
      .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
      .LAUNCH = LAUNCH
    });

    idx += 1;
  }

  // Export the host-callable RPC entry point (defined in pe_program.csl).
  @export_name("f_run", fn(i16)void);
}

pe_program.csl

// Not a complete program; the top-level source file is code.csl.

// Parameters supplied by @set_tile_code in code.csl.
param memcpy_params: comptime_struct;

param LAUNCH: color;            // RPC launch color
param MEMCPYH2D_DATA_1: color;  // streaming host-to-device data color
param MEMCPYD2H_DATA_1: color;  // streaming device-to-host data color

// Sentinel color: entrypoint to tell this PE that it is time to send
// the result to the host. Not routable; bound to a task below.
param end_computation: color;

const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
     .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
     .LAUNCH = LAUNCH
      }));

// number of wavelets expected via streaming H2D (set at runtime by f_run)
var num_wvlts: i16 = 0;

// number of wavelets received so far
var index: i16 = 0;

// running sum of all received wavelet payloads
var result :f16 = 0.0;

// Output DSD: a single f16 wavelet to the host on the streaming D2H color.
const dsd = @get_dsd(fabout_dsd, .{
  .extent = 1,
  .fabric_color = MEMCPYD2H_DATA_1
});

// Wavelet-triggered task for the streaming H2D color: accumulates each
// incoming f16 payload into `result`. Once the count reaches `num_wvlts`
// (configured by f_run before streaming starts), activates the sentinel
// color so send_result emits the final sum.
task main_task(data: f16) void {
  result = result + data;
  index += 1;
  if (index >= num_wvlts){
    // All wavelets received: trigger the task that sends the result out.
    @activate(end_computation);
  }
}

// Triggered by the sentinel color `end_computation`: forwards the single
// f16 sum to the host over the streaming D2H fabout DSD (extent 1).
// NOTE(review): an earlier comment here claimed "the length is 5" and that
// an async microthreaded operation is used, but the DSD above has
// .extent = 1 and @fmovh is invoked without microthread options — that
// comment appears stale; confirm against the original example intent.
task send_result() void {
  @fmovh(dsd, result);
}

comptime {
  // Data wavelets arriving on the H2D color trigger main_task.
  @bind_task(main_task, MEMCPYH2D_DATA_1);
  // The sentinel color is not routable, but it must still be bound to a
  // task so that @activate(end_computation) has an effect.
  @bind_task(send_result, end_computation);
}

// configure the number of wavelets via H2D
// RPC entry point called by the host before streaming begins: records how
// many H2D wavelets this PE should expect before emitting its result.
fn f_run(size: i16) void {
  num_wvlts = size;

  // WARNING: the user must unblock cmd color for every PE
  sys_mod.unblock_cmd_stream();
}

comptime{
  // Expose f_run to the host and serve RPC launches on the LAUNCH color.
  @export_symbol(f_run);
  @rpc(LAUNCH);
}

run.py

#!/usr/bin/env cs_python
"""Host driver for the sentinel-color example.

Streams a (size x num_wvlts) float16 tensor to a 1 x size column of PEs;
each PE sums its row of wavelets and, once its sentinel task fires,
streams the sum back. The host checks the result against a numpy sum.
"""

import argparse
import json
import numpy as np

from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime import runtime_utils # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata to recover the color IDs and PE-column size
# used at compile time (see commands.sh).
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYH2D_DATA_1 = int(params["MEMCPYH2D_DATA_1_ID"])
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
size = int(params["size"])
print(f"MEMCPYH2D_DATA_1 = {MEMCPYH2D_DATA_1}")
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")
print(f"size = number of PEs in a column = {size}")

# Each f16 element travels in the lower 16 bits of a 32-bit wavelet.
memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

runner.load()
runner.run()

# Number of wavelets each PE receives; also the row length of the tensor.
num_wvlts = 11
print(f"num_wvlts = number of wavelets for each PE = {num_wvlts}")

print("step 1: call f_run to configure number of input wavelets via H2D")
runner.launch("f_run", np.int16(num_wvlts), nonblock=False)

# Use a deterministic seed so that CI results are predictable
np.random.seed(seed=7)

# Set up a (size x num_wvlts) input tensor, reduced along the second dimension
input_tensor = np.random.rand(size, num_wvlts).astype(np.float16)
expected = np.sum(input_tensor, axis=1)

print("step 2: streaming H2D")
# input_tensor is 2-D; ravel() flattens it to a 1-D array for the transfer.
# The element type is float16, so each element must be extended to uint32.
# input_array_to_u32(np_arr, sentinel, fast_dim_sz) supports two extensions:
# 1) zero extension:            sentinel = None
# 2) upper 16 bits = index:     sentinel is not None
#
# Here sentinel=1 selects the index extension, but the upper 16 bits are
# "don't care" because pe_program.csl only reads the lower 16 bits.
tensors_u32 = runtime_utils.input_array_to_u32(input_tensor.ravel(), 1, num_wvlts)
# Nonblocking so the D2H below can be posted while wavelets are in flight.
runner.memcpy_h2d(MEMCPYH2D_DATA_1, tensors_u32, 0, 0, 1, size, num_wvlts, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=True)

print("step 3: streaming D2H")
# The D2H buffer must be of type u32; each PE contributes one wavelet.
out_tensors_u32 = np.zeros(size, np.uint32)
runner.memcpy_d2h(out_tensors_u32, MEMCPYD2H_DATA_1, 0, 0, 1, size, 1, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=False)
# Reinterpret the lower 16 bits of each u32 as the float16 result.
result_tensor = memcpy_view(out_tensors_u32, np.dtype(np.float16))

runner.stop()

# Ensure that the result matches our expectation
np.testing.assert_allclose(result_tensor, expected, atol=0.05, rtol=0)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

# Compile the CSL program and run the host script. Abort on any failure.
set -e

# Compile for a 1x10 column of PEs (size:10) placed at fabric offset (4,1)
# on an 8x12 fabric. Color assignments match the color map in code.csl:
#   0 = streaming H2D, 1 = streaming D2H, 2 = RPC launch.
cslc ./code.csl --fabric-dims=8,12 \
--fabric-offsets=4,1 -o out \
--params=size:10 \
--params=MEMCPYH2D_DATA_1_ID:0 \
--params=MEMCPYD2H_DATA_1_ID:1 --params=LAUNCH_ID:2 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out