Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions yateto/codegen/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, cpp, arch, target):
self._arch = arch
self._freeList = list()
self._target = target

def create(self, node, *args):
method = 'create_' + node.__class__.__name__
factory = getattr(self, method, self.generic_create)
Expand Down Expand Up @@ -74,25 +74,12 @@ def freeTmp(self):

self._freeList = []

def reset_stream(self):
    """Emit code resetting the batched-operations stream pointer to the
    forbidden sentinel. No-op for the 'cpu' target; raises for unknown targets."""
    if self._target == 'gpu':
        self._cpp(f'{BatchedOperationsAux.STREAM_PTR_NAME} = {BatchedOperationsAux.FORBIDDEN_STREAM_PTR};')
    elif self._target != 'cpu':
        raise RuntimeError('unknown compute target')

def reset_flags(self):
    """Emit code nulling out the batched-operations flags pointer.
    No-op for the 'cpu' target; raises for unknown targets."""
    if self._target == 'gpu':
        self._cpp(f'{BatchedOperationsAux.FLAGS_NAME} = nullptr;')
    elif self._target != 'cpu':
        raise RuntimeError('unknown compute target')

def _indices(self, var):
    """Build an Indices object naming each dimension of var with one
    lowercase letter (a, b, c, ...), in memory-layout order."""
    dims = var.memoryLayout().shape()
    letters = string.ascii_lowercase[:len(dims)]
    return Indices(letters, dims)

def supportsChainKernels(self):
    # Base factories cannot fuse multiple kernels into a single chain
    # kernel; ExportFactory overrides this to return True.
    return False

class OptimizedKernelFactory(KernelFactory):
def __init__(self, cpp, arch, target):
Expand Down Expand Up @@ -289,6 +276,9 @@ def generate(self, cpp, cache):
def add_linear_operation(self, dest, ops, target, permute, add):
    # Intentionally a no-op here; ExportFactory.handleLinear forwards to
    # self.generator.add_linear_operation, so concrete generators are
    # expected to implement this. TODO(review): confirm which generator overrides it.
    pass

def region_switch(self, barrier):
    # Region switches are a no-op for this backend; factories supporting
    # chain kernels forward this to their generator (see ExportFactory.region_switch).
    pass

class ExportFactory(KernelFactory):
@classmethod
def makeFactory(cls, generator):
Expand Down Expand Up @@ -364,3 +354,12 @@ def handleLinear(self, dest, ops, add, scalar, transposeA, transposeB):
permute += [[]]

return self.generator.add_linear_operation(dest, ops, target, permute, add)

def region_switch(self, barrier):
    # Forward a chain-region switch to the generator. `barrier` presumably
    # selects grid-wide (True) vs block-level (False) synchronization, per
    # the addChain description — confirm against the generator implementation.
    self.generator.region_switch(barrier)

def set_region_name(self, name):
    # Forward the region label (e.g. 'foo_kernel[0]', as built by
    # generateChainKernelOutline) to the generator so the emitted chain
    # code can tag the current sub-kernel region.
    self.generator.set_region_name(name)

def supportsChainKernels(self):
    # The export backend can merge several kernels into one chain kernel;
    # generateChainKernelOutline takes its fused path when this is True.
    return True
91 changes: 78 additions & 13 deletions yateto/codegen/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def deduce_scalar(self, action):
else:
return self.deduce_single_scalar(action.scalar)

def generate(self, cpp, cfg, factory, routineCache, gemm_cfg):
def generate(self, cpp, cfg, factory, routineCache, gemm_cfg):
hwFlops = 0
# temporary memory required (per element in case of gpu)
# NOTE: it is required to know in case if the memory is allocated on the heap
Expand Down Expand Up @@ -131,6 +131,7 @@ def __init__(self,
writable,
prefetch,
scalars,
kernels,
function,
tmp_mem_size,
is_compute_constant_tensors,
Expand All @@ -142,6 +143,7 @@ def __init__(self,
self.writable = writable
self.prefetch = prefetch
self.scalars = scalars
self.kernels = kernels
self.function = function
self.tmp_mem_size = tmp_mem_size
self.is_compute_constant_tensors = is_compute_constant_tensors
Expand Down Expand Up @@ -192,20 +194,73 @@ def generateKernelOutline(self, nonZeroFlops, cfg, gemm_cfg, target):
hwFlops, tmp_memory = super().generate(fcpp, cfg, factory, self._routineCache, gemm_cfg)
factory.post_generate(self._routineCache)
factory.freeTmp()
factory.reset_stream()
factory.reset_flags()
function = functionIO.getvalue()
return self.KernelOutline(nonZeroFlops,
hwFlops,
tensors,
writable,
prefetch,
scalars,
{},
function,
tmp_memory,
is_compute_constant_tensors,
target)

def generateChainKernelOutline(self, definition, gemm_cfg, target):
    """Generate code for a chain of kernels and wrap it into a KernelOutline.

    `definition` is a two-level list of (name, kernel, arg) triples (built by
    the generator from Generator.addChain input); `target` is 'cpu' or 'gpu'.
    """
    functionIO = StringIO()
    function = ''
    with Cpp(functionIO) as fcpp:
        factory = self._routine_factories[target](fcpp, self._arch, target)
        if factory.supportsChainKernels():
            # Fused path: inline each sub-kernel's body under a named region
            # and let the factory emit the region switches itself.
            isuper = super()  # bind outside the closures; zero-arg super() needs the method frame

            def generateKernelEntry(name, cfg, arg, index):
                factory.set_region_name(f'{name}_kernel[{index}]')
                return isuper.generate(fcpp, cfg.cfg, factory, self._routineCache, gemm_cfg)

            def region_switch(barrier):
                factory.region_switch(barrier)
        else:
            # Fallback path: just call the pre-generated member kernels in
            # sequence; region switches are no-ops.
            def generateKernelEntry(name, cfg, arg, index):
                if target == 'gpu':
                    # Propagate the chain kernel's stream/allocator into each sub-kernel.
                    fcpp(f'{name}_kernel[{index}].streamPtr = this->streamPtr;')
                    fcpp(f'{name}_kernel[{index}].linearAllocator = this->linearAllocator;')
                fcpp(f'{name}_kernel[{index}].execute{arg}();')
                return 0, 0  # (flops, tmp memory) unknown on this path

            def region_switch(barrier):
                pass

        # kernel base name -> list of (kernel, arg) entries, in call order;
        # len() of an entry's list doubles as the next array index.
        kernels = {}

        tmp_memory = 0
        for subdefinition in definition:
            for name, kernel, arg in subdefinition:
                if name not in kernels:
                    kernels[name] = []

                _, basename = Tensor.splitBasename(name)
                _, current_tmp = generateKernelEntry(basename, kernel, arg, len(kernels[name]))
                kernels[name] += [(kernel, arg)]
                # track the maximum temporary-memory requirement over all entries
                tmp_memory = max(current_tmp, tmp_memory)
            # NOTE(review): barrier=False after each inner list, barrier=True
            # once at the very end — confirm against region_switch semantics.
            region_switch(False)
        region_switch(True)
        factory.post_generate(self._routineCache)
        factory.freeTmp()
        function = functionIO.getvalue()

    # A chain outline carries no flop counts or tensor/scalar tables of its
    # own — only the kernel registry, the generated body, the temporary
    # memory requirement, and the compute target.
    return self.KernelOutline(0,
                              0,
                              {},
                              {},
                              {},
                              {},
                              kernels,
                              function,
                              tmp_memory,
                              {},
                              target)

@classmethod
def _addFromKO(cls, koEntries, entries):
for key, value in koEntries.items():
Expand All @@ -220,20 +275,19 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None):
prefetch = collections.OrderedDict()
writable = dict()
scalars = collections.OrderedDict()
kernels = collections.OrderedDict()
is_compute_constant_tensors = dict()
for ko in kernelOutlines:
if ko:
self._addFromKO(ko.scalars, scalars)
self._addFromKO(ko.tensors, tensors)
self._addFromKO(ko.kernels, kernels)
self._addFromKO(ko.writable, writable)
self._addFromKO(ko.prefetch, prefetch)
self._addFromKO(ko.is_compute_constant_tensors, is_compute_constant_tensors)

target = kernelOutlines[-1].target
is_same_target = True
for outline in kernelOutlines:
if outline:
is_same_target = True if outline.target == target else False
is_same_target = all(target == outline.target for outline in kernelOutlines)

if not is_same_target:
raise RuntimeError("kernels with the same family belong to different compute target.")
Expand Down Expand Up @@ -283,7 +337,7 @@ def generate(self, cpp, header, name, kernelOutlines, familyStride=None):

header.emptyline()

def kernelArgs(base_name_with_namespace, groups, writable, is_constant, target):
def tensorArgs(base_name_with_namespace, groups, writable, is_constant, target):
prefix, base_name = Tensor.splitBasename(base_name_with_namespace)
typ = self._arch.typename
ptr_type = '**' if not is_constant and target == 'gpu' else '*'
Expand All @@ -305,23 +359,34 @@ def scalarArgs(base_name_with_namespace, groups):
header(f'{class_name}::{container_type} {base_name};')
else:
header(f'{typ} {base_name} = std::numeric_limits<{typ}>::signaling_NaN();')

def kernelArgs(base_name_with_namespace, groups, target):
    # Declare a fixed-size member array with one chained-kernel object per
    # group; nothing is emitted when there are no groups.
    _, base_name = Tensor.splitBasename(base_name_with_namespace)
    if groups:
        header(f'{base_name_with_namespace} {base_name}_kernel[{len(groups)}];')

for baseName, groups in scalars.items():
scalarArgs(baseName,
groups)
for baseName, groups in tensors.items():
kernelArgs(baseName,
tensorArgs(baseName,
groups,
writable[baseName],
is_compute_constant_tensors[baseName],
target)
for baseName, groups in kernels.items():
kernelArgs(baseName,
groups,
target)
header.emptyline()

# containers with extra offsets for GPU-like computations
if target == 'gpu':
header(f'unsigned {BatchedOperationsAux.NUM_ELEMENTS_NAME} = 0;')
header(f'void *{BatchedOperationsAux.STREAM_PTR_NAME} = {BatchedOperationsAux.FORBIDDEN_STREAM_PTR};')
header(f'unsigned *{BatchedOperationsAux.FLAGS_NAME} = nullptr;')
if len(kernels) == 0: # TODO: better condition
header(f'unsigned {BatchedOperationsAux.NUM_ELEMENTS_NAME} = 0;')
header(f'unsigned* {BatchedOperationsAux.FLAGS_NAME} = nullptr;')

header(f'void* {BatchedOperationsAux.STREAM_PTR_NAME} = {BatchedOperationsAux.FORBIDDEN_STREAM_PTR};')

def generate_extra_offset_args(base_name_with_namespace, groups):
prefix, base_name = Tensor.splitBasename(base_name_with_namespace)
Expand All @@ -341,7 +406,7 @@ def generate_extra_offset_args(base_name_with_namespace, groups):
if len(prefetch) > 0:
with header.Struct(self.PREFETCHSTRUCT_NAME):
for baseName, groups in prefetch.items():
kernelArgs(baseName, groups, writable=False, is_constant=False, target='any')
tensorArgs(baseName, groups, writable=False, is_constant=False, target='any')
header('{} {};'.format(self.PREFETCHSTRUCT_NAME, self.PREFETCHVAR_NAME))
header.emptyline()

Expand Down
83 changes: 83 additions & 0 deletions yateto/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class KernelFamily(object):
GROUP_INDEX = r'\((0|[1-9]\d*)\)'
VALID_NAME = r'^{}({})$'.format(Kernel.BASE_NAME, GROUP_INDEX)

GROUP_INDEX2 = r'\(((?:0|[1-9]\d*)(?:,(?:0|[1-9]\d*))*)\)'
VALID_NAME2 = r'^{}({})$'.format(Kernel.BASE_NAME, GROUP_INDEX2)

def __init__(self, namespace=None):
self._kernels = dict()
self.name = None
Expand All @@ -132,10 +135,19 @@ def baseName(self, name):
def isValidName(cls, name):
return re.match(cls.VALID_NAME, name) is not None

@classmethod
def isValidName2(cls, name):
    """Return True when name is a valid multi-index family instance,
    e.g. 'kernel(1,2)' per VALID_NAME2."""
    return bool(re.match(cls.VALID_NAME2, name))

@classmethod
def group(cls, name):
    """Extract the single group index embedded in name, e.g. 'k(3)' -> 3."""
    match = re.search(cls.GROUP_INDEX, name)
    return int(match.group(1))

def groupLinear(self, name):
    """Map a multi-index instance name like 'k(1,2)' to its linearized
    group index using the family's stride vector."""
    match = re.search(self.GROUP_INDEX2, name)
    indices = [int(part) for part in match.group(1).split(',')]
    return self.linear(self.stride(), indices)

def setStride(self, stride):
    # Stride vector used to linearize multi-dimensional group indices
    # (consumed by groupLinear via self.stride()).
    self._stride = stride
Expand Down Expand Up @@ -185,6 +197,20 @@ def simpleParameterSpace(*args):
def parameterSpaceFromRanges(*args):
    """Return the full Cartesian product of the given ranges as a list of tuples."""
    expanded = [list(rng) for rng in args]
    return list(itertools.product(*expanded))

class KernelChain:
    """Descriptor for a chain of kernels to be combined into one generated kernel.

    `setup` is the two-level list handed to Generator.addChain; `namespace`
    defaults to the empty string when not supplied.
    """

    def __init__(self, name, namespace, setup, target):
        self.name = name
        self.namespace = '' if namespace is None else namespace
        self.setup = setup
        self.target = target

    @classmethod
    def isValidName(cls, name):
        """A chain name must be a plain kernel base name (no group suffix)."""
        return re.match(Kernel.VALID_NAME, name) is not None

class GlobalRoutineCache:
def __init__(self):
self.cache = RoutineCache()
Expand Down Expand Up @@ -231,6 +257,7 @@ def __init__(self, outputDir, name):
def __init__(self, arch):
self._kernels = list()
self._kernelFamilies = dict()
self._kernelChains = list()
self._arch = arch

def arch(self):
Expand Down Expand Up @@ -274,6 +301,24 @@ def addFamily(self,
prefetch = prefetchGenerator(*p) if prefetchGenerator is not None else None
family.add(indexedName, ast, prefetch, namespace, target=target)

def addChain(self, name: str, chainDesc: List[List[str]], namespace=None,
             target='cpu'):
    """Register a kernel chain, i.e. combine multiple kernels into a single one.

    Needs support from the routine exporter (to, e.g., merge them into GPU
    kernels); otherwise it is equivalent to calling the respective kernels
    in sequence.

    The `chainDesc` parameter is a two-level list:
    * Between consecutive entries of the outer list, a grid-wide
      synchronization is inserted — as if the kernels were launched in
      sequence.
    * Within an inner list, kernels may potentially overlap (in the GPU
      case, only block-level synchronization is inserted).
    * Each inner-list element is the name of a kernel previously registered
      via the other registration functions.

    The names are only validated once the kernels are generated. Kernels in
    a chain may have differing element counts; family kernel instances
    (e.g. 'name(1,2)') are allowed as well.
    """

    self._kernelChains.append(KernelChain(name, namespace, chainDesc, target))

@classmethod
def _headerGuardName(self, namespace, fileBaseName):
partlist = namespace.upper().split('::') + [fileBaseName.upper(), self.HEADER_GUARD_SUFFIX]
Expand Down Expand Up @@ -348,6 +393,13 @@ def unit_test_body(cpp, testFramework):
kernel_family_dict[family.namespace].append(family)
else:
kernel_family_dict[family.namespace] = [family]

chainkernel_dict = {}
for kernel in self._kernelChains:
if kernel.namespace in chainkernel_dict:
chainkernel_dict[kernel.namespace].append(kernel)
else:
chainkernel_dict[kernel.namespace] = [kernel]

print('Generating kernels...')
if routine_cache is None:
Expand Down Expand Up @@ -395,6 +447,33 @@ def unit_test_body(cpp, testFramework):

with cpp.Namespace(family_namespace), header.Namespace(family_namespace):
optKernelGenerator.generate(cpp, header, family.name, kernelOutlines, family.stride())

# group kernel chains by namespace
for chain_namespace, chainkernels in chainkernel_dict.items():
for chainkernel in chainkernels:

chainDescC = []
for innerDesc in chainkernel.setup:
innerDescC = []
for kernelname in innerDesc:
chainnsp, basename = Tensor.splitBasename(kernelname)
if Kernel.isValidName(basename):
innerDescC += [(kernelname, self._kernels[basename], '')]
elif KernelFamily.isValidName2(basename):
familyname = KernelFamily.baseName(basename)
family = self._kernelFamilies[familyname]
grp = family.groupLinear(basename)

fullfamilyname = f'{chainnsp}kernel::{familyname}'
innerDescC += [(fullfamilyname, family._kernels[grp], grp)]
else:
raise NotImplementedError(f'Invalid name: {kernelname}')
chainDescC += [innerDescC]
kernelOutline = optKernelGenerator.generateChainKernelOutline(chainDescC,
gemm_cfg,
chainkernel.target)
with cpp.Namespace(chain_namespace), header.Namespace(chain_namespace):
optKernelGenerator.generate(cpp, header, chainkernel.name, [kernelOutline])
kernelSourceContent = kernelSource.getvalue()

with Cpp(fKernels.cpp) as cpp:
Expand Down Expand Up @@ -481,3 +560,7 @@ def add(self, *args, **kwargs):
@_add_ns
def addFamily(self, *args, **kwargs):
    # Thin wrapper; _add_ns presumably injects this wrapper's namespace
    # before delegating to the underlying Generator — confirm in _add_ns.
    return self.generator.addFamily(*args, **kwargs)

@_add_ns
def addChain(self, *args, **kwargs):
    # Thin wrapper; _add_ns presumably injects this wrapper's namespace
    # before delegating to Generator.addChain — confirm in _add_ns.
    return self.generator.addChain(*args, **kwargs)