Wavelets for Sparse Tensors

When tensors are sparse, it is wasteful to send zero values. Since wavelet payloads are 32 bits wide, we can use the lower 16 bits to carry the data value as usual, while using the upper 16 bits to carry the index of that value within the tensor.

This example illustrates that scheme: each wavelet of the incoming tensor carries the element's index in its upper 16 bits. Accordingly, the task definition uses two function arguments, one for the lower 16 bits and another for the upper 16 bits.

Optionally, the programmer may instead declare a task with just one argument of type u32 and unpack the 32-bit payload manually.
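
As a minimal host-side sketch of this packing scheme (the helper names below are illustrative, not part of the SDK):

# Hypothetical helpers mirroring the 32-bit wavelet layout described above.
def pack_wavelet(index: int, value: int) -> int:
    # Index goes in the upper 16 bits, the data value in the lower 16 bits.
    assert 0 <= index < (1 << 16) and 0 <= value < (1 << 16)
    return (index << 16) | value

def unpack_wavelet(word: int) -> tuple[int, int]:
    # Mirrors the two 16-bit task arguments received on the PE.
    return word >> 16, word & 0xFFFF

assert pack_wavelet(3, 26) == 0x0003001A
assert unpack_wavelet(0x0003001A) == (3, 26)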

code.csl

// The core kernel must start at P4.1 so the memcpy infrastructure has enough
// resources to route the data between the host and the device.
//

// color map
//
// color  var      color  var    color  var                color  var
//   0    H2D         9            18                        27    reserved (memcpy)
//   1    D2H        10            19                        28    reserved (memcpy)
//   2    LAUNCH     11            20                        29    reserved
//   3               12            21    reserved (memcpy)   30    reserved (memcpy)
//   4               13            22    reserved (memcpy)   31    reserved
//   5               14            23    reserved (memcpy)   32
//   6               15            24                        33
//   7               16            25                        34
//   8               17            26                        35
//

param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;
param LAUNCH_ID: i16;

const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const LAUNCH: color = @get_color(LAUNCH_ID);

const memcpy = @import_module("<memcpy_multi/get_params>", .{
  .width = 1,
  .height = 1
});

layout {
  @set_rectangle(1, 1);

  const memcpy_params = memcpy.get_params(0);

  @set_tile_code(0, 0, "pe_program.csl", .{
    .memcpy_params = memcpy_params,
    .MEMCPYH2D_DATA_1 = MEMCPYH2D_DATA_1,
    .MEMCPYD2H_DATA_1 = MEMCPYD2H_DATA_1,
    .LAUNCH = LAUNCH
  });

}

pe_program.csl

// Not a complete program; the top-level source file is code.csl.

param memcpy_params: comptime_struct;

param LAUNCH: color;
param MEMCPYH2D_DATA_1: color;
param MEMCPYD2H_DATA_1: color;

const sys_mod = @import_module("<memcpy_multi/memcpy>", @concat_structs(memcpy_params, .{
  .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
  .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
  .LAUNCH = LAUNCH
}));

// Buffer for the full 4-element tensor; elements whose wavelets are never
// sent remain zero.
var result = [4]i16 { 0, 0, 0, 0 };

// Output DSD that forwards each received value to the host over the D2H color.
const dsd = @get_dsd(fabout_dsd, .{
  .extent = 1,
  .fabric_color = MEMCPYD2H_DATA_1
});

task main_task(wavelet_data: i16, index: i16) void {
  result[index] = wavelet_data;
  // The synchronous operation works here because only two wavelets are sent;
  // a larger stream should use the asynchronous form sketched below.
  @mov16(dsd, wavelet_data);
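  // A hedged sketch of the asynchronous form (assuming @mov16 accepts the same
  // .{ .async = true } option as other DSD operations):
  //   @mov16(dsd, wavelet_data, .{ .async = true });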
}

comptime {
  @bind_task(main_task, MEMCPYH2D_DATA_1);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYH2D_DATA_1 = int(params["MEMCPYH2D_DATA_1_ID"])
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
print(f"MEMCPYH2D_DATA_1 = {MEMCPYH2D_DATA_1}")
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")

memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

runner.load()
runner.run()

# Turn each tuple of two 16-bit integers into one 32-bit integer
packed = [(idx << 16) + val for idx, val in [(0, 42), (3, 26)]]
packed_tensor = np.array(packed, dtype=np.uint32)
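# e.g. (0, 42) packs to 0x0000002A and (3, 26) to 0x0003001A; the zero-valued
# elements at indices 1 and 2 are never transmitted.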

print("step 1: streaming H2D")
# "packed_tensor" must be an 1d array of type u32
runner.memcpy_h2d(MEMCPYH2D_DATA_1, packed_tensor, 0, 0, 1, 1, 2, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=True)

print("step 2: streaming D2H")
# The D2H buffer must be of type u32
out_tensors_u32 = np.zeros(2, np.uint32)
runner.memcpy_d2h(out_tensors_u32, MEMCPYD2H_DATA_1, 0, 0, 1, 1, 2, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=False)
# Remove the upper 16 bits of each u32, keeping only the 16-bit data values
result_tensor = memcpy_view(out_tensors_u32, np.dtype(np.int16))

runner.stop()

# Ensure that the result matches our expectation.
# Since zero wavelets are skipped during transmission, the `@mov16` operation
# in the device code executes only twice, once for each nonzero wavelet.
np.testing.assert_equal(result_tensor, [42, 26])
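# On the device, `result` now holds [42, 0, 0, 26]: each index carried in the
# upper 16 bits scattered its value into place.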
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

set -e

cslc ./code.csl --fabric-dims=8,3 \
--fabric-offsets=4,1 -o out \
--params=MEMCPYH2D_DATA_1_ID:0 \
--params=MEMCPYD2H_DATA_1_ID:1 --params=LAUNCH_ID:2 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out