@map Builtin

The @map builtin can be used to perform custom operations on the data elements of one or more DSDs. In other words, it is a customizable DSD operation that allows us to go beyond the fixed list of natively supported DSD operations.

This example demonstrates three use-cases of the @map builtin:

  1. In the first use-case, @map is used to compute the square-root of the diagonal elements of a 2D tensor.

  2. In the second use-case, @map is used to perform a custom calculation with a mix of input DSDs of type mem1d_dsd and scalar values, with the result stored to a mem1d_dsd. It shows how we can use arbitrary callbacks combined with a variety of input and output DSDs.

  3. Finally, we demonstrate how @map can be used to compute a reduction like the sum of all elements in a tensor.

Without @map, we would have to write explicit loops iterating over each element involved in these computations. With @map we can avoid writing such loops by utilizing the DSD descriptions which specify the loop structure implicitly. Since DSDs are supported natively by the hardware, using @map can lead to significant performance gains compared to writing explicit loops.

code.csl

// Colors used as resources to route the data between the host and the device.
//

// color map
//
// color  var    color  var              color  var              color  var
//   0 H2D          9                      18                      27   reserved (memcpy)
//   1 D2H         10                      19                      28   reserved (memcpy)
//   2 LAUNCH      11                      20                      29   reserved
//   3             12                      21    reserved (memcpy) 30   reserved (memcpy)
//   4             13                      22    reserved (memcpy) 31   reserved
//   5             14                      23    reserved (memcpy) 32
//   6             15                      24                      33
//   7             16                      25                      34
//   8 main        17                      26                      35
//

// Problem size; forwarded unchanged to pe_program.csl as `.size`.
param size: i16;

// Numeric IDs of the colors used for host-to-device data, device-to-host
// data, and kernel launch. They are supplied on the cslc command line.
param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;
param LAUNCH_ID: i16;

// Turn the numeric IDs into color values.
const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const LAUNCH: color = @get_color(LAUNCH_ID);

// Color 8 is listed as `main` in the color map above; pe_program.csl binds
// its `main` task to this color.
const main_color: color = @get_color(8);

// memcpy parameter helper, configured for the 1x1 rectangle used below.
const memcpy = @import_module( "<memcpy_multi/get_params>", .{
    .width = 1,
    .height = 1
    });

// Layout: a single PE at (0, 0) running pe_program.csl.
layout {
  @set_rectangle(1, 1);

  // memcpy parameters obtained with argument 0; they are passed to the PE
  // program below as `.memcpy_params`.
  const memcpy_params = memcpy.get_params(0);

  // Every field below corresponds to a `param` declared in pe_program.csl.
  @set_tile_code(0, 0, "pe_program.csl", .{
    .size = size,
    .main_color = main_color,
    .memcpy_params = memcpy_params,
    .MEMCPYH2D_DATA_1 = MEMCPYH2D_DATA_1,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

  // export symbol name
  // "weight" and "sqrt_diag_A" are looked up by run.py with runner.get_id();
  // "f_run" is the function run.py starts with runner.launch().
  @export_name("weight", [*]f16, true);
  @export_name("sqrt_diag_A", [*]f16, true);
  @export_name("f_run", fn()void);
}

pe_program.csl

// Not a complete program; the top-level source file is code.csl.

// Color that the `main` task is bound to (set by code.csl).
param main_color: color;

// Number of elements processed by each DSD defined in this file.
param size: i16;

// memcpy configuration and colors, all supplied by code.csl.
param memcpy_params: comptime_struct;
param LAUNCH: color;
param MEMCPYH2D_DATA_1: color;
param MEMCPYD2H_DATA_1: color;

// memcpy module reserves input queue 0 and output queue 0
// The layout-provided parameters are merged with the three colors above
// before being handed to the memcpy module.
const sys_mod = @import_module( "<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
     .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
     .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
     .LAUNCH = LAUNCH
      }));

// size-by-size tensor with every element initialized to 42.0.
// It must be `var`, not `const`: the second @map in `main` writes its result
// back through `dsdA`, i.e. into the diagonal of A. It is exported because
// run.py reads it back by the name "A".
export var A = @constants([size, size]f16, 42.0);

// Input of the reduction example.
// NOTE: the initializer lists exactly five values, so this example is only
// consistent with size == 5 (the value used in commands.sh).
const B = [size]i16{10, 20, 30, 40, 50};

const math_lib = @import_module("<math>");

// Output of the square-root example and the per-element weights of the
// transformation example. The host accesses both through the pointers below.
var sqrt_diag_A = @zeros([size]f16);
var weight = @zeros([size]f16);

// Pointers exported as the symbols "weight" and "sqrt_diag_A".
var ptr_weight: [*]f16 = &weight;
var ptr_sqrt_diag_A: [*]f16 = &sqrt_diag_A;

// The loop structure is implicitly specified by the memory DSD descriptions
// dsdA walks the diagonal A[i, i]; dsdB walks every element of B.
const dsdA = @get_dsd(mem1d_dsd, .{.tensor_access = |i|{size} -> A[i, i]});
const dsdB = @get_dsd(mem1d_dsd, .{.tensor_access = |i|{size} -> B[i]});

const dsd_sqrt_diag_A = @get_dsd(mem1d_dsd, .{.tensor_access = |i|{size} -> sqrt_diag_A[i]});
const dsd_weight = @get_dsd(mem1d_dsd, .{.tensor_access = |i|{size} -> weight[i]});

// Accumulator of the reduction example. Exported, like A, because run.py
// reads it back by the name "sum".
export var sum : i16 = 0;

// Callback of the second @map call in `main`.
// Returns value * (coeff1 + weight) + value * (coeff2 + weight).
// The `weight` parameter shadows the file-level `weight` array inside this
// function.
fn transformation(value : f16, coeff1 : f16, coeff2 : f16, weight : f16) f16 {
  return value * (coeff1 + weight) + value * (coeff2 + weight);
}

// Callback of the third @map call in `main`.
// Returns the value currently stored behind `sum` plus `value`; it does not
// write through the pointer itself. The `sum` parameter shadows the
// file-level `sum` variable inside this function.
fn reduction(value : i16, sum : *i16) i16 {
  return sum.* + value;
}

// Runs the three @map examples, then unblocks the memcpy command stream.
task main() void {
  // Compute the square-root of each element of `dsdA` (the diagonal
  // of A) and store it to `dsd_sqrt_diag_A`.
  //
  // Notice how we avoid writing an explicit loop and rely
  // on the DSD description instead.
  @map(math_lib.sqrt_f16, dsdA, dsd_sqrt_diag_A);

  // Transform the diagonal of tensor A in-place through a custom
  // calculation: `dsdA` is both the first input and the output, 2.0 and 6.0
  // are scalar inputs, and `dsd_weight` supplies the per-element weight.
  @map(transformation, dsdA, 2.0, 6.0, dsd_weight, dsdA);

  // Compute the sum of all elements in tensor B.
  // `&sum` is passed both as an input and as the destination.
  @map(reduction, dsdB, &sum, &sum);

  // WARNING: the user must unblock cmd color for every PE
  sys_mod.unblock_cmd_stream();
}

// Bind the `main` task to `main_color`, so that activating that color runs it.
comptime {
  @bind_task(main, main_color);
}

// Host-callable entry point (exported below and launched by run.py).
// It only activates `main_color`; the work itself happens in the `main` task.
fn f_run() void {
  @activate(main_color);

  // terminate when main() finishes
}

// Export the symbols whose names are declared with @export_name in code.csl,
// and set up RPC on the LAUNCH color.
comptime{
  @export_symbol(ptr_weight, "weight");
  @export_symbol(ptr_sqrt_diag_A, "sqrt_diag_A");
  @export_symbol(f_run);
  @rpc(LAUNCH);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.debug.debug_util import debug_util
from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime import runtime_utils # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

# Command-line interface: --name is the directory holding the compiled
# artifacts, --cmaddr is the optional address of a CS system.
parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument('--cmaddr', help='IP:port for CS system')
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
size = int(params["size"])
print(f"size = {size}")

# All host<->device copies in this script use the 16-bit memcpy data type.
memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

# Look up the symbols exported by pe_program.csl.
sym_weight = runner.get_id("weight")
sym_sqrt_diag_A = runner.get_id("sqrt_diag_A")

runner.load()
runner.run()

# Host-side copies of the device tensors A and B from pe_program.csl.
# They are hard-coded for the 5-element case (size:5 in commands.sh).
A = np.array([[42.0, 42.0, 42.0, 42.0, 42.0],
              [42.0, 42.0, 42.0, 42.0, 42.0],
              [42.0, 42.0, 42.0, 42.0, 42.0],
              [42.0, 42.0, 42.0, 42.0, 42.0],
              [42.0, 42.0, 42.0, 42.0, 42.0]]).astype(np.float16)
B = np.array([10, 20, 30, 40, 50]).astype(np.int16)

def transformation(value: np.ndarray, coeff1: float, coeff2: float,
                   weight: np.ndarray) -> np.ndarray:
  """Host-side reference for the `transformation` callback in pe_program.csl.

  Computes, element-wise, value * (coeff1 + weight) + value * (coeff2 + weight).
  The two products are deliberately kept separate (not factored) so that the
  sequence of roundings matches the device expression.

  Args:
    value: input array.
    coeff1: scalar added to `weight` in the first product.
    coeff2: scalar added to `weight` in the second product.
    weight: per-element weights, broadcastable against `value`.

  Returns:
    Array holding the sum of the two products.
  """
  first_term = np.multiply(value, coeff1 + weight)
  second_term = np.multiply(value, coeff2 + weight)
  return first_term + second_term

def reduction(array):
  """Host-side reference for the device reduction: add up every element.

  Starts from the integer 0 and adds the elements one by one in iteration
  order, exactly like the builtin `sum`.
  """
  total = 0
  for element in array:
    total = total + element
  return total

# Fixed seed so the random weights are reproducible from run to run.
np.random.seed(seed=7)

print("step 1: copy mode H2D")
# Random float16 weights, converted to u32 words and copied to the device
# symbol "weight" of the single PE at (0, 0).
weights = np.random.random(size).astype(np.float16)
tensors_u32 = runtime_utils.input_array_to_u32(weights, 0, size)
runner.memcpy_h2d(sym_weight, tensors_u32, 0, 0, 1, 1, size, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=True)

print("step 2: call f_run to test @map")
# f_run activates the device task that performs the three @map calls.
runner.launch("f_run", nonblock=False)

print("step 3: copy mode D2H")
# The D2H buffer must be of type u32
out_tensors_u32 = np.zeros(size, np.uint32)
runner.memcpy_d2h(out_tensors_u32, sym_sqrt_diag_A, 0, 0, 1, 1, size, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=False)
# remove upper 16-bit of each u32
sqrt_result = memcpy_view(out_tensors_u32, np.dtype(np.float16))

runner.stop()

# Square-root example: compare the device result with numpy's sqrt of the
# diagonal of the (still untransformed) host copy of A.
expected = np.sqrt(np.diag(A))
np.testing.assert_equal(sqrt_result, expected)

# The remaining checks read device symbols by name through debug_util.
# The offsets equal the --fabric-offsets=4,1 passed to cslc in commands.sh.
debug_mod = debug_util(dirname, cmaddr=args.cmaddr)
core_offset_x = 4
core_offset_y = 1
print(f"=== dump core: core rectangle starts at {core_offset_x}, {core_offset_y}")

# Transformation example
# Apply the reference transformation to the diagonal of the host copy of A,
# then compare the whole matrix with the device symbol "A".
expected = transformation(np.diag(A), 2.0, 6.0, weights)
np.fill_diagonal(A, expected)
actual = debug_mod.get_symbol(core_offset_x, core_offset_y, "A", np.float16)
np.testing.assert_equal(actual.reshape((5, 5)), A)

# Reduction example
# NOTE: the names are swapped relative to the checks above: `sum_result` is
# the host-computed reference and `expected` is the value read from the device.
sum_result = np.array([reduction(B)], dtype=np.int16)
expected = debug_mod.get_symbol(core_offset_x, core_offset_y, "sum", np.int16)
np.testing.assert_equal(sum_result, expected)

print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

# Abort on the first failing command, so run.py is not started after a
# failed compile.
set -e

# Compile code.csl into ./out.
#  - size:5 matches the five-element initializer of B in pe_program.csl and
#    the 5x5 host arrays in run.py.
#  - --fabric-offsets=4,1 matches core_offset_x / core_offset_y in run.py.
#  - The color IDs 0, 1 and 2 match the H2D, D2H and LAUNCH rows of the color
#    map in code.csl.
cslc ./code.csl \
--fabric-dims=8,3 --fabric-offsets=4,1 \
--params=size:5 \
-o out \
--params=MEMCPYH2D_DATA_1_ID:0 \
--params=MEMCPYD2H_DATA_1_ID:1 --params=LAUNCH_ID:2 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
# Run the host script against the compiled output directory.
cs_python run.py --name out