Switches

In all previous examples, all routes were static, i.e. once defined, they were set for the lifetime of the program. Fabric switches permit limited runtime control of routes.

In this example, the layout block initializes the default route to receive wavelets from the ramp and forward them to the PE’s north neighbor. However, it also defines routes for switch positions 1, 2, and 3. The hardware updates the route according to the specified switch positions when it receives a so-called Control Wavelet.

For the payload of the control wavelet, the code creates a special wavelet using the helper function ctrl(), which will become a library function in the future.

Switches are helpful not just for changing the routing configuration in limited ways at runtime, but also for reducing the number of colors used. For instance, this same example could be rewritten to use four colors and four routes, but by using fabric switches, it uses just one color.

code.csl

// The memcpy infrastructure reserves resources (see the color map below)
// to route the data between the host and the device.
//

// color map
//
// color  var    color  var              color  var              color  var
//   0              9  channel             18                      27   reserved (memcpy)
//   1 D2H         10                      19                      28   reserved (memcpy)
//   2 LAUNCH      11                      20                      29   reserved
//   3             12                      21    reserved (memcpy) 30   reserved (memcpy)
//   4             13                      22    reserved (memcpy) 31   reserved
//   5             14                      23    reserved (memcpy) 32
//   6             15                      24                      33
//   7             16                      25                      34
//   8 start       17                      26                      35
//

// Color IDs for the memcpy D2H stream and the launch channel. The values are
// supplied at compile time (commands.sh passes them via --params).
param MEMCPYD2H_DATA_1_ID: i16;
param LAUNCH_ID: i16;

// Colors corresponding to the IDs above; forwarded to every PE program.
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const LAUNCH: color = @get_color(LAUNCH_ID);

// Color 8: passed to send.csl as mainColor, to which it binds its main task.
const start: color = @get_color(8);

// Color 9: the single fabric color used for all four sends. The raw ID is
// kept as well because send.csl writes it into the control wavelet payload.
const colorValue = 9;
const channel: color = @get_color(colorValue);

// memcpy parameter provider for the 3 x 3 rectangle; the layout block calls
// get_params() with the x coordinate of each PE.
const memcpy = @import_module( "<memcpy_multi/get_params>", .{
    .width = 3,
    .height = 3
    });

layout {
  // The program occupies a 3 x 3 rectangle of PEs.
  @set_rectangle(3, 3);

  // Roles of the nine PEs:
  //   - The center PE (1,1) runs send.csl and sends four control wavelets,
  //     one to each of its four adjacent neighbors.
  //   - The four adjacent neighbors run recv.csl and are programmed to
  //     receive those control wavelets.
  //   - The four corner PEs run launch.csl and take no part in the exchange.

  // memcpy parameters, indexed by the x coordinate of the PE.
  const params_x0 = memcpy.get_params(0);
  const params_x1 = memcpy.get_params(1);
  const params_x2 = memcpy.get_params(2);

  // ---- Sender (center PE) ----
  @set_tile_code(1, 1, "send.csl", .{
    .mainColor = start,
    .txColor = channel,
    .colorValue = colorValue,
    .memcpy_params = params_x1,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // ---- Receivers: inDir is the side on which the sender lies ----

  // West neighbor, receives from the east.
  @set_tile_code(0, 1, "recv.csl", .{
    .rxColor = channel,
    .inDir = EAST,
    .fin = false,
    .memcpy_params = params_x0,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // East neighbor, receives from the west.
  @set_tile_code(2, 1, "recv.csl", .{
    .rxColor = channel,
    .inDir = WEST,
    .fin = false,
    .memcpy_params = params_x2,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // South neighbor, receives from the north.
  @set_tile_code(1, 2, "recv.csl", .{
    .rxColor = channel,
    .inDir = NORTH,
    .fin = false,
    .memcpy_params = params_x1,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // North neighbor, receives from the south. This is the only PE with
  // fin = true: it sends the final message back to the host to signal
  // completion.
  @set_tile_code(1, 0, "recv.csl", .{
    .rxColor = channel,
    .inDir = SOUTH,
    .fin = true,
    .memcpy_params = params_x1,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // ---- Corner PEs ----
  @set_tile_code(0, 0, "launch.csl", .{
    .memcpy_params = params_x0,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });
  @set_tile_code(0, 2, "launch.csl", .{
    .memcpy_params = params_x0,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });
  @set_tile_code(2, 0, "launch.csl", .{
    .memcpy_params = params_x2,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });
  @set_tile_code(2, 2, "launch.csl", .{
    .memcpy_params = params_x2,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // Symbol name exported to the host; every PE program defines f_run.
  @export_name("f_run", fn()void);
}

send.csl

// Not a complete program; the top-level source file is code.csl.

// Color on which this PE transmits its control wavelets (its routes and
// switch positions are configured in the comptime block below).
param txColor: color;
// Color to which mainTask is bound; f_run() activates it.
param mainColor: color;
// Numeric ID of txColor; ctrl() embeds it in the control wavelet payload.
param colorValue;

// memcpy parameters for this PE, produced by the layout in code.csl.
param memcpy_params: comptime_struct;

// Colors handed to the memcpy module (launch channel and D2H stream).
param LAUNCH: color;
param MEMCPYD2H_DATA_1: color;

// memcpy module instance; f_run() uses it to unblock the command stream.
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
     .LAUNCH = LAUNCH
      }));

// Fabric output DSD that emits a single wavelet (extent = 1) on txColor.
const dsd = @get_dsd(fabout_dsd, .{
  .extent = 1,
  .fabric_color = txColor,

  // Specify that this wavelet is a control wavelet
  .control = true,
});

// Opcodes for potentially updating switches.
// Only opcode_switch_advance is used in this file.
const opcode_nop = 0;
const opcode_switch_advance = 1;
const opcode_switch_reset = 2;
const opcode_teardown = 3;

// Build the 32-bit payload of a control wavelet.
//
// args:
//    ce_filter: a filter bit to disable transmission from the destination
//               router to the destination CE,
//    opcode:    switching opcode (one of the opcode_* constants), and
//    data:      16-bit wavelet data
//
// Layout of the returned value:
//    upper 16 bits: colorValue | (opcode << 6) | (ce_filter << 8)
//    lower 16 bits: data
fn ctrl(ce_filter: bool, opcode: i16, data: u16) u32 {
  // Shift amounts, typed as u32 like the values they are applied to.
  const opcode_shift = @as(u32, 6);
  const filter_shift = @as(u32, 8);
  const upper_half_shift = @as(u32, 16);

  // Individual fields of the control (upper) half-word.
  const color_bits = @as(u32, colorValue);
  const opcode_bits = @as(u32, opcode) << opcode_shift;
  const filter_bits = @as(u32, ce_filter) << filter_shift;

  const control_word = color_bits | opcode_bits | filter_bits;
  const data_word = @as(u32, data);

  return (control_word << upper_half_shift) | data_word;
}

// Task bound to mainColor (see the comptime block below). Sends four control
// wavelets on the single color txColor, one per neighbor of this PE.
task mainTask() void {
  // Now we can reuse a single color to send four different values to the four
  // neighbors of this PE.  The four wavelets will be sent over four
  // consecutive cycles.
  //
  // Every wavelet carries opcode_switch_advance. The switch configuration
  // for txColor starts at position 1 (WEST), advances through position 2
  // (EAST) and position 3 (SOUTH), and then wraps around to position 0, the
  // default route (NORTH).

  // Send 0xaa along the first (WEST) direction
  // Since all arguments to this function are known at compile time, we make
  // this a `comptime` function call.
  @mov32(dsd, comptime ctrl(false, opcode_switch_advance, 0xaa));

  // Send 0xbb along the second (EAST) direction
  @mov32(dsd, comptime ctrl(false, opcode_switch_advance, 0xbb));

  // Send 0xcc along the third (SOUTH) direction
  @mov32(dsd, comptime ctrl(false, opcode_switch_advance, 0xcc));

  // Send 0xdd along the fourth (NORTH) direction
  @mov32(dsd, comptime ctrl(false, opcode_switch_advance, 0xdd));
}

comptime {
  // mainTask runs when mainColor is activated.
  @bind_task(mainTask, mainColor);

  // Routing of txColor, including the positions selectable at runtime.
  @set_local_color_config(txColor, .{
    .routes = .{
      // Position 0, the default route: receive from ramp and send to north
      .rx = .{ RAMP },
      .tx = .{ NORTH }
    },
    .switches = .{
      // Each further position replaces only the transmit direction; a
      // control wavelet moves the switch from one position to the next.
      .pos1 = .{ .tx = WEST },
      .pos2 = .{ .tx = EAST },
      .pos3 = .{ .tx = SOUTH },

      // Start at position 1, so the send order is west PE, east PE,
      // south PE, and finally north PE.
      .current_switch_pos = 1,

      // After position 3, the next control wavelet wraps around to
      // position 0.
      .ring_mode = true,
    }
  });
}

// Entry point launched by the host (run.py launches "f_run").
// Of the three PE programs, only send.csl triggers the data sending.
fn f_run() void {
  // Activate mainColor, to which mainTask is bound.
  @activate(mainColor);

  // RPC returns early before the data is sent out via D2H color
  // The host must wait for streaming D2H

  // WARNING: the user must unblock cmd color for every PE
  sys_mod.unblock_cmd_stream();
}

// Export f_run so the host can refer to it by name.
// NOTE(review): @rpc(LAUNCH) presumably attaches the host launch mechanism
// to the LAUNCH color -- confirm against the memcpy documentation.
comptime{
  @export_symbol(f_run);
  @rpc(LAUNCH);
}

recv.csl

// Not a complete program; the top-level source file is code.csl.

// Color on which this PE receives the wavelet from the sender.
param rxColor: color;

// fin: when true, this PE sends a wavelet to the host after receiving.
// inDir: direction from which wavelets on rxColor arrive.
param fin: bool;
param inDir: direction;

// memcpy parameters for this PE, produced by the layout in code.csl.
param memcpy_params: comptime_struct;

// Colors handed to the memcpy module (launch channel and D2H stream).
param LAUNCH: color;
param MEMCPYD2H_DATA_1: color;

// memcpy module instance; f_run() uses it to unblock the command stream.
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
     .LAUNCH = LAUNCH
      }));

// Fabric output DSD that emits a single wavelet on the D2H color.
const dsd = @get_dsd(fabout_dsd, .{.fabric_color = MEMCPYD2H_DATA_1, .extent = 1});

// Last value received by rxTask; run.py reads it back by the name "global".
export var global:u16 = 0;

// Task bound to rxColor (see the comptime block below): stores the received
// 16-bit data in `global`.
task rxTask(data: u16) void {
  global = data;

  // Only the PE configured with fin = true (P1.0 in code.csl) notifies the
  // host: it writes a single 16-bit zero to the D2H color.
  if (fin) {
    @mov16(dsd, 0);
  }
}

comptime {
  // rxTask handles the wavelets that arrive on rxColor.
  @bind_task(rxTask, rxColor);

  // Receive from inDir (the side on which the sender lies) and transmit
  // to RAMP.
  const routes = .{
    .rx = .{ inDir },
    .tx = .{ RAMP }
  };

  @set_local_color_config(rxColor, .{ .routes = routes });
}

// Entry point launched by the host (run.py launches "f_run").
// Only send.csl triggers the data sending; a receiver merely unblocks the
// command stream here.
fn f_run() void {
  // WARNING: the user must unblock cmd color for every PE
  sys_mod.unblock_cmd_stream();
}

// Export f_run so the host can refer to it by name.
// NOTE(review): @rpc(LAUNCH) presumably attaches the host launch mechanism
// to the LAUNCH color -- confirm against the memcpy documentation.
comptime{
  @export_symbol(f_run);
  @rpc(LAUNCH);
}

launch.csl

// Not a complete program; the top-level source file is code.csl.

// memcpy parameters for this PE, produced by the layout in code.csl.
param memcpy_params: comptime_struct;

// Colors handed to the memcpy module (launch channel and D2H stream).
param LAUNCH: color;
param MEMCPYD2H_DATA_1: color;

// memcpy module instance; f_run() uses it to unblock the command stream.
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
     .LAUNCH = LAUNCH
      }));

// Entry point launched by the host (run.py launches "f_run").
// Only send.csl triggers the data sending; this PE sends and receives
// nothing and merely unblocks the command stream.
fn f_run() void {
  // WARNING: the user must unblock cmd color for every PE
  sys_mod.unblock_cmd_stream();
}

// Export f_run so the host can refer to it by name.
// NOTE(review): @rpc(LAUNCH) presumably attaches the host launch mechanism
// to the LAUNCH color -- confirm against the memcpy documentation.
comptime{
  @export_symbol(f_run);
  @rpc(LAUNCH);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.debug.debug_util import debug_util
from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

# Command line: directory of the compiled program and optional system address.
cli = argparse.ArgumentParser()
cli.add_argument('--name', help='the test name')
cli.add_argument("--cmaddr", help="IP:port for CS system")
args = cli.parse_args()
dirname = args.name

# The compile metadata records the color ID chosen for the D2H stream.
with open(f"{dirname}/out.json", encoding="utf-8") as metadata_file:
  metadata = json.load(metadata_file)
MEMCPYD2H_DATA_1 = int(metadata["params"]["MEMCPYD2H_DATA_1_ID"])
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")

# Load the program and start it.
sdk = SdkRuntime(dirname, cmaddr=args.cmaddr)
sdk.load()
sdk.run()

print("step 1: call f_run to start at PE1.1")
sdk.launch("f_run", nonblock=False)

print("step 2: streaming D2H at P1.0 (end of communication)")
# The device sends 16-bit data, but the host-side D2H buffer must be u32.
d2h_buffer = np.zeros(1, np.uint32)
sdk.memcpy_d2h(
    d2h_buffer, MEMCPYD2H_DATA_1, 1, 0, 1, 1, 1,
    streaming=True,
    data_type=MemcpyDataType.MEMCPY_16BIT,
    order=MemcpyOrder.COL_MAJOR,
    nonblock=False,
)
# View the buffer as 16-bit values, dropping the upper 16 bits of each u32.
result_tensor = memcpy_view(d2h_buffer, np.dtype(np.int16))

sdk.stop()

# Read back what each receiver stored in its `global` symbol.
inspector = debug_util(dirname, cmaddr=args.cmaddr)
core_offset_x = 4
core_offset_y = 1
print(f"=== core rectangle starts at {core_offset_x}, {core_offset_y}")

# Receiver positions inside the core rectangle; the sender PE is P1.1.
receiver_coords = {
  "top": (1, 0),     # P1.0
  "left": (0, 1),    # P0.1
  "right": (2, 1),   # P2.1
  "bottom": (1, 2),  # P1.2
}
received = {
  side: inspector.get_symbol(core_offset_x + dx, core_offset_y + dy, "global", np.uint16)
  for side, (dx, dy) in receiver_coords.items()
}

# Each neighbor must hold the value that the sender emitted in its direction.
expected_by_side = (("top", 0xdd), ("left", 0xaa), ("right", 0xbb), ("bottom", 0xcc))
for side, expected in expected_by_side:
  np.testing.assert_allclose(received[side], expected)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

# Stop at the first failing command, so that the host script does not run
# when compilation fails.
set -o errexit

# Compile the layout. The color IDs passed via --params match the color map
# in code.csl, and --fabric-offsets matches the core offsets used in run.py.
cslc ./code.csl \
  --fabric-dims=10,5 \
  --fabric-offsets=4,1 \
  -o out \
  --params=MEMCPYD2H_DATA_1_ID:1 \
  --params=LAUNCH_ID:2 \
  --memcpy \
  --channels=1 \
  --width-west-buf=0 \
  --width-east-buf=0

# Run the host script against the compiled output directory.
cs_python run.py --name out