Filters

Fabric filters allow a PE to selectively accept incoming wavelets. This example shows the use of so-called range filters, which select the wavelets that are forwarded to the CE based on the upper 16 bits of each wavelet's contents. Specifically, PE #0 sends all 12 wavelets to the other PEs, while each recipient PE accepts and processes only a quarter (three) of the incoming wavelets. See the documentation for other possible filter configurations.

code.csl

// The memcpy infrastructure reserves resources to route the data between the host and the device.
//

// color map
//
// color  var    color  var          color  var              color  var
//   0              9  dataColor       18                      27   reserved (memcpy)
//   1 D2H         10                  19                      28   reserved (memcpy)
//   2 LAUNCH      11                  20                      29   reserved
//   3             12                  21    reserved (memcpy) 30   reserved (memcpy)
//   4             13                  22    reserved (memcpy) 31   reserved
//   5             14                  23    reserved (memcpy) 32
//   6             15                  24                      33
//   7             16                  25                      34
//   8 mainColor   17                  26                      35
//

// Numeric IDs of the memcpy device-to-host color and of the launch color.
// Both are compile-time parameters of the program.
param MEMCPYD2H_DATA_1_ID: i16;
param LAUNCH_ID: i16;

// The same two IDs turned into color values.
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const LAUNCH: color = @get_color(LAUNCH_ID);

// Colors used by the example itself (see the color map above):
// mainColor is handed to the sender for its main task, dataColor carries the
// wavelets exchanged between the PEs.
const mainColor: color = @get_color(8);
const dataColor: color = @get_color(9);

// memcpy parameter module, instantiated for the 4 x 1 rectangle of PEs.
const memcpy = @import_module( "<memcpy_multi/get_params>", .{
    .width = 4,
    .height = 1
    });

// Places one sender (PE 0) and three receivers (PEs 1-3) in a single row.
layout {
  @set_rectangle(4, 1);

  // One memcpy parameter struct per PE, indexed by the PE's column.
  const memcpy_params_0 = memcpy.get_params(0);
  const memcpy_params_1 = memcpy.get_params(1);
  const memcpy_params_2 = memcpy.get_params(2);
  const memcpy_params_3 = memcpy.get_params(3);

  // PE (0, 0) is the sender: it gets both the task color and the data color.
  @set_tile_code(0, 0, "send.csl", .{
    .peId = 0,
    .mainColor = mainColor,
    .exchColor = dataColor,
    .memcpy_params = memcpy_params_0,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // Parameters shared by all receivers; the per-PE fields (peId and
  // memcpy_params) are appended with @concat_structs below.
  const recvStruct = .{
    .recvColor = dataColor,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  };
  @set_tile_code(1, 0, "recv.csl", @concat_structs(recvStruct, .{ .peId = 1, .memcpy_params = memcpy_params_1 }));
  @set_tile_code(2, 0, "recv.csl", @concat_structs(recvStruct, .{ .peId = 2, .memcpy_params = memcpy_params_2 }));
  @set_tile_code(3, 0, "recv.csl", @concat_structs(recvStruct, .{ .peId = 3, .memcpy_params = memcpy_params_3 }));

  // Names and types of the symbols exported to the host: the result buffer
  // "buf" and the host-callable function "f_run".
  @export_name("buf", [*]f16, true);
  @export_name("f_run", fn()void);
}

send.csl

// Not a complete program; the top-level source file is code.csl.

// Identifier of this PE; the layout in code.csl passes 0 for the sender.
param peId: u16;

// mainColor is bound to mainTask; exchColor carries the outgoing wavelets.
param mainColor: color;
param exchColor: color;

// memcpy parameters and colors forwarded from the layout.
param memcpy_params: comptime_struct;
param LAUNCH: color;
param MEMCPYD2H_DATA_1: color;

// memcpy module instance of this PE; mainTask uses it to unblock the
// command stream.
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
     .LAUNCH = LAUNCH
      }));


/// Builds one 32-bit wavelet from a 16-bit index and a 16-bit float:
/// the index occupies the upper half, the raw bits of the float the lower half.
fn pack(index: u16, data: f16) u32 {
  const upper_half = @as(u32, index) << 16;
  const lower_half = @as(u32, @bitcast(u16, data));
  return upper_half | lower_half;
}

// The twelve wavelets to send: indices 0..11 packed together with the
// float values 10.0..21.0 (index in the upper 16 bits, see pack()).
const size = 12;
const data = [size]u32 {
  pack(0, 10.0),  pack( 1, 11.0), pack( 2, 12.0),
  pack(3, 13.0),  pack( 4, 14.0), pack( 5, 15.0),
  pack(6, 16.0),  pack( 7, 17.0), pack( 8, 18.0),
  pack(9, 19.0),  pack(10, 20.0), pack(11, 21.0),
};

/// Moves every element of `data` onto the fabric on color `exchColor`.
/// Where those wavelets travel is determined by the route configured for
/// that color.
fn sendDataToEastTiles() void {
  // Destination: fabric output carrying `size` wavelets on exchColor.
  const fabric_out = @get_dsd(fabout_dsd, .{
    .extent = size,
    .fabric_color = exchColor,
  });

  // Source: a single pass over the packed wavelets held in memory.
  const packed_words = @get_dsd(mem1d_dsd, .{
    .tensor_access = |k|{size} -> data[k]
  });

  @mov32(fabric_out, packed_words);
}


// Result buffer of this PE, exported to the host under the name "buf":
// one f16 slot for each of the num_wvlts values this PE processes itself.
const num_wvlts: i16 = 3;
var buf = @zeros([num_wvlts]f16);
var ptr_buf : [*]f16 = &buf;

/// Halves the payload of the first `num_wvlts` entries of `data` and stores
/// the results in `buf`.
///
/// Despite its name, this function sends nothing over the fabric; it only
/// fills the buffer that is exported to the host as "buf".
fn processAndSendSubset() void {
  // Bound the loop by num_wvlts, the length of buf, instead of a literal so
  // the loop and the buffer cannot drift apart. i16 matches num_wvlts.
  var idx: i16 = 0;
  while (idx < num_wvlts) : (idx += 1) {
    // The float payload sits in the lower 16 bits of the packed wavelet;
    // the upper 16 bits hold the index and are masked off.
    const payload = @as(u16, data[idx] & 0xffff);
    const floatValue = @bitcast(f16, payload);
    buf[idx] = floatValue / 2.0;
  }
}

// Task bound to mainColor; runs once f_run activates that color.
task mainTask() void {
  // send all wavelets on exchColor; the route for that color forwards them
  // to the east only, so this PE does not receive its own wavelets
  sendDataToEastTiles();
  // prepare this PE's own share of the results in "buf"
  processAndSendSubset();

  // WARNING: the user must unblock cmd color for every PE
  sys_mod.unblock_cmd_stream();
}

comptime {
  // mainTask runs when mainColor is activated (see f_run).
  @bind_task(mainTask, mainColor);

  // Wavelets on exchColor come from this PE's compute element (RAMP) and
  // are forwarded to the east neighbor.
  @set_local_color_config(exchColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ EAST } } });
}

// Function exported to the host as "f_run".
// only sender triggers the broadcasting: activating mainColor starts mainTask
fn f_run() void {
  @activate(mainColor);
  // terminate when the mainTask is done; mainTask itself unblocks the
  // command stream as its last step
}

comptime{
  // Export the result buffer and the entry point under the names declared
  // with @export_name in code.csl.
  @export_symbol(ptr_buf, "buf");
  @export_symbol(f_run);
  // NOTE(review): presumably sets up the host-call mechanism on the LAUNCH
  // color - confirm against the SDK documentation.
  @rpc(LAUNCH);
}

recv.csl

// Not a complete program; the top-level source file is code.csl.

// Identifier of this PE; the layout in code.csl passes 1, 2 or 3.
// It selects which wavelet indices this PE accepts (see the filter below).
param peId: u16;

// Color on which the wavelets arrive; bound to recvTask.
param recvColor: color;

// memcpy parameters and colors forwarded from the layout.
param memcpy_params: comptime_struct;
param LAUNCH: color;
param MEMCPYD2H_DATA_1: color;

// memcpy module instance of this PE; recvTask uses it to unblock the
// command stream.
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
     .LAUNCH = LAUNCH
      }));

// Number of wavelets this PE stores, which is also the length of buf.
const num_wvlts: i16 = 3;
// Next free slot in buf; recvTask advances it once per invocation.
var index: i16 = 0;

// Result buffer of this PE, exported to the host under the name "buf".
var buf = @zeros([num_wvlts]f16);
var ptr_buf : [*]f16 = &buf;

// Task bound to recvColor. Each invocation stores half of the incoming
// value in the next free slot of buf; once num_wvlts values are stored the
// command stream is released.
task recvTask(data: f16) void {
  const halved = data / 2.0;
  buf[index] = halved;
  index += 1;

  const received_all = index >= num_wvlts;
  if (received_all) {
     // every expected wavelet has arrived, let the next command proceed
     // WARNING: the user must unblock cmd color for every PE
     sys_mod.unblock_cmd_stream();
  }
}

// Binds recvTask to recvColor and configures the route and the range filter
// for that color.
comptime {
  @bind_task(recvTask, recvColor);

  // f_run() unblocks this color to receive the broadcasting value
  @block(recvColor);

  // Every receiver takes its input from the west neighbor.
  const baseRoute = .{
    .rx = .{ WEST }
  };

  const filter = .{
      // Each PE should only accept three wavelets starting with the one whose
      // index field contains the value peId * 3, i.e. the indices
      // peId * 3, peId * 3 + 1 and peId * 3 + 2.
      .kind = .{ .range = true },
      .min_idx = peId * 3,
      .max_idx = peId * 3 + 2,
    };

  if (peId == 3) {
    // This is the last PE, don't forward the wavelet further to the east.
    const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP } });
    @set_local_color_config(recvColor, .{.routes = txRoute, .filter = filter});
  } else {
    // Otherwise, forward incoming wavelets to both CE and to the east neighbor.
    const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP, EAST } });
    @set_local_color_config(recvColor, .{.routes = txRoute, .filter = filter});
  }
}

// Function exported to the host as "f_run".
// only sender triggers the broadcasting
// receiver unblocks recvColor to receive the data from the sender
fn f_run() void {
  // starts to receive the data from the sender
  @unblock(recvColor);

  // terminates only when all wavelets are received: recvTask unblocks the
  // command stream after it has stored num_wvlts values
}

comptime{
  // Export the result buffer and the entry point under the names declared
  // with @export_name in code.csl.
  @export_symbol(ptr_buf, "buf");
  @export_symbol(f_run);
  // NOTE(review): presumably sets up the host-call mechanism on the LAUNCH
  // color - confirm against the SDK documentation.
  @rpc(LAUNCH);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

# Geometry of the run: 4 PEs in a single row, each returning 3 float16 values.
PE_COLUMNS = 4
PE_ROWS = 1
VALUES_PER_PE = 3

cli = argparse.ArgumentParser()
cli.add_argument('--name', help='the test name')
cli.add_argument("--cmaddr", help="IP:port for CS system")
options = cli.parse_args()
artifact_dir = options.name

# Read the compile-time parameters recorded in the compile output.
with open(f"{artifact_dir}/out.json", encoding="utf-8") as metadata_file:
  compile_params = json.load(metadata_file)["params"]
d2h_color_id = int(compile_params["MEMCPYD2H_DATA_1_ID"])
print(f"MEMCPYD2H_DATA_1 = {d2h_color_id}")

runner = SdkRuntime(artifact_dir, cmaddr=options.cmaddr)
buf_symbol = runner.get_id("buf")

runner.load()
runner.run()

print("step 1: call f_run to start broadcasting")
runner.launch("f_run", nonblock=False)

print("step 2: copy D2H")
# The host-side D2H buffer must be uint32 even though the device data is 16-bit.
raw_words = np.zeros(PE_COLUMNS * PE_ROWS * VALUES_PER_PE, np.uint32)
runner.memcpy_d2h(
    raw_words, buf_symbol, 0, 0, PE_COLUMNS, PE_ROWS, VALUES_PER_PE,
    streaming=False,
    data_type=MemcpyDataType.MEMCPY_16BIT,
    order=MemcpyOrder.ROW_MAJOR,
    nonblock=False,
)
# Drop the upper 16 bits of every word and view the rest as float16.
halved = memcpy_view(raw_words, np.dtype(np.float16))

runner.stop()

# Every PE halves its three values, so 10.0..21.0 becomes 5.0..10.5.
expected = [5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5]
np.testing.assert_allclose(halved, expected, atol=0.0001, rtol=0)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

# Abort as soon as any command fails.
set -o errexit

# Compile the layout into ./out; the two --params values are the color IDs
# read by code.csl.
cslc ./code.csl \
  --fabric-dims=11,3 \
  --fabric-offsets=4,1 \
  -o out \
  --params=MEMCPYD2H_DATA_1_ID:1 \
  --params=LAUNCH_ID:2 \
  --memcpy \
  --channels=1 \
  --width-west-buf=0 \
  --width-east-buf=0

# Run the host program against the compiled output directory.
cs_python run.py --name out