Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/yateto-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
- name: Codegen Tests
run: |
cd ./tests/code-gen
for example in matmul minimal indices slicing; do
for example in matmul minimal indices slicing regress; do
for build_type in Debug Release; do
for precision in single double; do
echo " ====== Test Config: ======"
Expand Down
36 changes: 36 additions & 0 deletions tests/code-gen/regress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env python3

from yateto import *
from yateto.ast.node import Add

def add(g):
M = 32
K = 40
A = Tensor('A', (M, K))
B = Tensor('B', (M, K))
C = Tensor('C', (M, K))

class Counter:
def __init__(self):
self.counter = 0

counter = Counter()

def _(kernel):
counter.counter += 1
g.add(f'kernel{counter.counter}', kernel)

# regression tests

# list bugs with their PR solving them here

# #103.1
# allow one-element sum accumulations
_(B['ij'] <= Add() + A['ij'])

# #103.2
# prevent overriding a global variable when action merging
_([
B['ij'] <= A['ij'],
C['ij'] <= B['ij']
])
6 changes: 5 additions & 1 deletion yateto/controlflow/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def visit(self, cfg):
V = ua.variables()
for j in range(i+1,n):
va = cfg[j].action
if va.isRHSVariable() and ua.result == va.term and va.result not in V and (ua.hasTrivialScalar() or va.hasTrivialScalar()):
if va.isRHSVariable() \
and ua.result == va.term \
and va.result not in V \
and (ua.hasTrivialScalar() or va.hasTrivialScalar()) \
and ua.result.isLocal():
found = j
break
elif ua.result in va.variables() or ua.result == va.result:
Expand Down
2 changes: 1 addition & 1 deletion yateto/controlflow/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def visit_SliceView(self, node):

def visit_Add(self, node):
variables = [self.visit(child) for child in node]
assert len(variables) > 1
assert len(variables) >= 1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify, could you please elaborate a bit on this change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The computation also works for length 1 without adjustments (but not so for 0 I think, but I haven't tried it); so all we do here is make the assert more lenient for that exact case.

(it doesn't quite return the original tensor here; but will still generate a copy-scale-add operation—though that one should be, IIRC, optimized away or merged with other operations afterwards)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we verify that it does not work for 0 first? If it also works for 0, maybe you could remove the assert.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it—it seems to fail on some more fronts (i.e. it takes longer than 10 mins to fix it).


variables.sort(key=lambda var: int(not var.writable) + int(not var.isGlobal()))

Expand Down
2 changes: 1 addition & 1 deletion yateto/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def unit_test_body(cpp, testFramework):
cpp.include(fInit.hName)
with cpp.Namespace(namespace):
initGen.generateInitCpp(cpp)

prefixnsp = lambda a: a.name if a.namespace == '' else f'{a.namespace}::{a.name}'
return {
'namespace': namespace,
Expand Down
8 changes: 4 additions & 4 deletions yateto/metagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def headerForward(name, data):
for entry in data:
self.template(header, entry, data[entry], f'{name}')


headerForward('tensor', tensors)
headerForward('init', tensors)
headerForward('kernel', kernels)
Expand All @@ -105,7 +105,7 @@ def cppForward(name):
for gendata in self.generators:
outdirname = f'metagen_{gendata["name"]}'
header.include(f'{outdirname}/{name}.cpp')

cppForward('tensor')
cppForward('init')
cppForward('kernel')
Expand All @@ -131,12 +131,12 @@ def inner():

templatetypes = ', '.join(f'{typ} Arg{i}' for i, typ in enumerate(self.templateType))
templateargs = ', '.join(f'Arg{i}' for i, _ in enumerate(self.templateType))

with header.Namespace('internal'):
header(f'template<{templatetypes}> struct {internalName} {"{"} using Type = void; {"}"};')
for gnsp, spec in foundin:
spectext = ', '.join(str(specpart) for specpart in spec)
header(f'template<> struct {internalName}<{spectext}> {"{"} using Type = ::{gnsp}::{fullname}; {"}"};')
header(f'template<{templatetypes}> using {name} = typename internal::{internalName}<{templateargs}>::Type;')

self.namespacing(header, splitname[:-1] + [subnsp], inner)
Loading