Skip to content

Commit ce7cf3c

Browse files
committed
Add struct return
1 parent f1cb9dc commit ce7cf3c

File tree

5 files changed

+124
-43
lines changed

5 files changed

+124
-43
lines changed

examples/start/main.asa

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,21 @@ create :: string(c : *char) #inline; {
3131

3232
s.tS.x = 5;
3333

34-
len : uint32 = 0;
34+
len : uint32 = 15;
3535
//for(i : 0..4_294_967_296){
3636
// //if(c[i] == '\0')
3737
// // break;
3838
// len++;
3939
//}
40-
//s.address = malloc(len);
40+
s.address = malloc(len);
4141
//len = 5;
4242
//s.address = malloc(5);
43-
//for(i : 0..len){
44-
// s.address[i] = 'A';
45-
// //s.address[i] = c[i];
46-
//}
47-
//s.length = len;
43+
for(i : 0..len){
44+
s.address[i] = 'A';
45+
//s.address[i] = c[i];
46+
}
47+
s.length = len;
48+
s.address = "String value example";
4849
puts("Done!\n");
4950
return s;
5051
}
@@ -108,6 +109,7 @@ main :: (){
108109
//printint(s.length);
109110
//s.size();
110111
puts("printing string...");
112+
puts(s.address);
111113
s.print();
112114
puts("Done!");
113115

modules/Builtin/asa.asa

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,15 @@ Strings :: module{
175175
//}
176176

177177
}
178+
179+
PipeOperatorTest :: module{
180+
fnX :: int(x : int){
181+
return x*5;
182+
}
183+
fnY :: int(y : int){
184+
return y*5;
185+
}
186+
pipeTest::int(z : int){
187+
//return fnX(z) -> fnY(%);
188+
}
189+
}

src/codegen.cpp

Lines changed: 100 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -68,20 +68,24 @@ struct functionID {
6868
std::string mangledName = "";
6969
std::string returnType = "";
7070
argumentList arguments = argumentList();
71+
argumentList userArguments = argumentList();
7172
bool variableNumArguments = false;
73+
bool isStructReturn = false;
7274
uint32_t uses = 0;
7375
bool isMemberFunction = false;
7476
Function* fnValue = nullptr;
7577
functionID() {}
76-
functionID(std::string n, std::string mN, std::string r, argumentList a, Function* f, bool vA = false, bool mF = false)
78+
functionID(std::string n, std::string mN, std::string r, argumentList llvmArgs, argumentList userArgs, Function* f, bool vA = false, bool mF = false, bool sRet = false)
7779
{
7880
name = n;
7981
mangledName = mN;
8082
returnType = r;
81-
arguments = a;
83+
arguments = llvmArgs;
84+
userArguments = userArgs;
8285
fnValue = f;
8386
variableNumArguments = vA;
8487
isMemberFunction = mF;
88+
isStructReturn = sRet;
8589
}
8690
void print()
8791
{
@@ -123,12 +127,12 @@ struct functionID {
123127
uint16_t differences = 0;
124128
if (n != name)
125129
return 1000;
126-
if (arguments.size() != a.size())
130+
if (userArguments.size() != a.size())
127131
return 1000 - 1;
128-
for (int i = 0; i < arguments.size(); i++) {
129-
ASTNodeType t1 = arguments[i].baseASTType;
132+
for (int i = 0; i < userArguments.size(); i++) {
133+
ASTNodeType t1 = userArguments[i].baseASTType;
130134
ASTNodeType t2 = a[i].baseASTType;
131-
bool mustBeExactType = arguments[i].mustBeExactType;
135+
bool mustBeExactType = userArguments[i].mustBeExactType;
132136
// If t1 is an integer type, make sure t2 is also
133137
// Difference points are given the further the types are
134138

@@ -163,12 +167,12 @@ struct functionID {
163167
uint16_t differences = 0;
164168
if (n != name)
165169
return 1000;
166-
if (arguments.size() != a.size())
170+
if (userArguments.size() != a.size())
167171
return 1000 - 1;
168-
for (int i = 0; i < arguments.size(); i++) {
169-
ASTNodeType t1 = arguments[i].baseASTType;
172+
for (int i = 0; i < userArguments.size(); i++) {
173+
ASTNodeType t1 = userArguments[i].baseASTType;
170174
ASTNodeType t2 = a[i]->nodeType;
171-
bool mustBeExactType = arguments[i].mustBeExactType;
175+
bool mustBeExactType = userArguments[i].mustBeExactType;
172176
// If t1 is an integer type, make sure t2 is also
173177
// Difference points are given the further the types are
174178

@@ -841,8 +845,6 @@ void* ASTNode::generateConstant(int pass)
841845
return ConstantFP::get(*TheContext, APFloat(stod(token->first)));
842846
else if (nodeType == String_Constant_Node) {
843847
std::string strValue = unescapeString(token->first.substr(1, token->first.size() - 2), token); // remove quotes from token
844-
// Add null terminator
845-
strValue += '\0';
846848

847849
GlobalVariable* globalStr = nullptr;
848850
if (globalStringLiteralConstants.find(strValue) != globalStringLiteralConstants.end())
@@ -1873,15 +1875,17 @@ void* ASTNode::generateCallExpression(int pass)
18731875
ASTNode* argsNode = childNodes[0];
18741876
bool shouldBeMemberFunction = isCallMemberFunction;
18751877
isCallMemberFunction = false;
1878+
18761879
std::vector<ASTNode*> args = std::vector<ASTNode*>();
18771880
for (auto& a : argsNode->childNodes)
18781881
if (a->childNodes.size() > 0) {
18791882
args.push_back(a);
1880-
//aList.push_back(std::make_pair());
18811883
}
18821884

18831885
std::vector<Value*> ArgsV = std::vector<Value*>();
18841886
argumentList argList = argumentList();
1887+
1888+
// Build argList and ArgsV WITHOUT sret initially
18851889
for (int i = 0; i < args.size(); i++) {
18861890
Value* argVal = (Value*)(args[i]->*(args[i]->codegen))(pass);
18871891
ArgsV.push_back(argVal);
@@ -1890,10 +1894,8 @@ void* ASTNode::generateCallExpression(int pass)
18901894
return nullptr;
18911895
}
18921896

1893-
1894-
// Look up the id in the global module table.
1897+
// Look up the function ID using the caller's argList (without sret)
18951898
functionID* CalleeFID = getFunctionFromID(functionIDs, token->first, argList, token, true, shouldBeMemberFunction);
1896-
//Function* CalleeF = TheModule->getFunction(token->first);
18971899
if (!CalleeFID) {
18981900
if (shouldBeMemberFunction)
18991901
printf("Should be member function\n");
@@ -1903,28 +1905,51 @@ void* ASTNode::generateCallExpression(int pass)
19031905
}
19041906
exit(1);
19051907
}
1908+
19061909
Function* CalleeF = CalleeFID->fnValue;
19071910
CalleeFID->uses++;
19081911

1909-
// If argument mismatch error.
1910-
if (CalleeFID->variableNumArguments == false)
1911-
if (CalleeF->arg_size() != args.size()) {
1912-
CalleeFID->print();
1913-
printTokenError(token, "Incorrect number of arguments passed to function", __LINE__);
1912+
1913+
bool isStructReturn = CalleeFID->isStructReturn;
1914+
AllocaInst* sretAlloc = nullptr;
1915+
if (isStructReturn) {
1916+
// Get the struct type from definitions
1917+
structType* retStruct = structDefinitions[CalleeFID->returnType];
1918+
if (retStruct->structVal == nullptr) {
1919+
printTokenError(token, "Struct return type not fully defined");
19141920
exit(1);
19151921
}
1916-
// If variable arguments, make sure the amount in call are <= the required amount
1917-
else if (CalleeF->arg_size() > args.size()) {
1922+
1923+
// Allocate space for the returned struct on the caller's stack
1924+
sretAlloc = Builder->CreateAlloca(retStruct->structVal, nullptr, "sret_alloc");
1925+
1926+
// Insert the sret pointer as the FIRST argument in ArgsV
1927+
ArgsV.insert(ArgsV.begin(), sretAlloc);
1928+
1929+
// For argList matching: Temporarily add sret to argList for validation
1930+
// (This matches how it's stored in functionID)
1931+
argType sretArg("*" + CalleeFID->returnType, Struct_Type, 1, false, true);
1932+
argList.insert(argList.begin(), sretArg);
1933+
}
1934+
1935+
// Validate argument count (now including sret if applicable)
1936+
if (CalleeFID->variableNumArguments == false) {
1937+
if (CalleeF->arg_size() != ArgsV.size()) { // Use ArgsV.size() which includes sret
1938+
printTokenError(token, "Incorrect number of arguments passed to function (expected " + std::to_string(CalleeF->arg_size()) + ")", __LINE__);
19181939
CalleeFID->print();
1919-
printTokenError(token, "Incorrect number of arguments passed to function", __LINE__);
19201940
exit(1);
19211941
}
1942+
}
1943+
else if (CalleeF->arg_size() > ArgsV.size()) {
1944+
printTokenError(token, "Incorrect number of arguments passed to function", __LINE__);
1945+
CalleeFID->print();
1946+
exit(1);
1947+
}
19221948

1923-
// Clear arg values list to get values correctly
1924-
ArgsV = std::vector<Value*>();
1925-
for (int i = 0; i < args.size(); i++) {
1926-
if (CalleeFID->arguments[i].isReference) {
1927-
if (args[0]->childNodes.size() != 1 || args[0]->childNodes[0]->nodeType != Identifier_Node) {
1949+
ArgsV.clear();
1950+
for (int i = 0; i < args.size(); i++) { // Start from caller's args (sret is already handled)
1951+
if (CalleeFID->arguments[i + (isStructReturn ? 1 : 0)].isReference) { // Offset by 1 if sret
1952+
if (args[i]->childNodes.size() != 1 || args[i]->childNodes[0]->nodeType != Identifier_Node) {
19281953
printTokenError(token, "Cannot pass value as reference");
19291954
exit(1);
19301955
}
@@ -1936,10 +1961,30 @@ void* ASTNode::generateCallExpression(int pass)
19361961
return nullptr;
19371962
}
19381963

1964+
// If struct return, re-insert sret as first arg (after rebuilding)
1965+
if (isStructReturn) {
1966+
ArgsV.insert(ArgsV.begin(), sretAlloc);
1967+
}
1968+
1969+
// Create the call (returns void for struct returns)
1970+
Value* callResult = Builder->CreateCall(CalleeF, ArgsV, "calltmp");
1971+
1972+
// For struct returns, return the loaded struct value (or pointer if lvalue)
1973+
if (isStructReturn) {
1974+
if (lvalue) {
1975+
return sretAlloc; // Return pointer for lvalue contexts (e.g., assignment)
1976+
}
1977+
else {
1978+
structType* retStruct = structDefinitions[CalleeFID->returnType];
1979+
return Builder->CreateLoad(retStruct->structVal, sretAlloc, "sret_load");
1980+
}
1981+
}
19391982

1940-
return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
1983+
// Non-struct: Return the call result directly
1984+
return callResult;
19411985
}
19421986

1987+
19431988
// Value*
19441989
void* ASTNode::generateIf(int pass)
19451990
{
@@ -2258,6 +2303,7 @@ void* ASTNode::generatePrototype(int pass)
22582303
Type* retType = Type::getVoidTy(*TheContext);
22592304
std::string rTypeString = "";
22602305
ASTNode* typeNode = childNodes[1];
2306+
bool isStructReturn = false;
22612307
if (typeNode->childNodes.size() > 0) {
22622308
recurseAddPointer:
22632309
typeNode = typeNode->childNodes[0];
@@ -2273,12 +2319,29 @@ void* ASTNode::generatePrototype(int pass)
22732319

22742320
bool wasDefined = true;
22752321
retType = getLLVMTypeFromString(rTypeString, 0, typeNode->token, wasDefined, pass);
2276-
//if (wasDefined == false)
2277-
// return nullptr;
2322+
2323+
// If the return type is a struct
2324+
if (retType && retType->isStructTy()) {
2325+
isStructReturn = true;
2326+
}
2327+
}
2328+
2329+
argumentList userArgList = argList;
2330+
2331+
// Handle struct return by modifying function signature
2332+
Type* actualRetType = retType;
2333+
if (isStructReturn) {
2334+
// For struct returns, add sret parameter as first argument
2335+
argTypes.push_back(retType->getPointerTo()); // sret parameter (pointer to struct)
2336+
argNames.push_back("sret");
2337+
argList.insert(argList.begin(), argType("*" + rTypeString, getASTNodeTypeFromString(rTypeString), 1, false, true));
2338+
2339+
// Change actual return type to void
2340+
actualRetType = Type::getVoidTy(*TheContext);
22782341
}
22792342

22802343
// Get function arguments
2281-
// If it is a struct, first add a "this" argument like: (this : ref structName, ...)
2344+
// If it is a struct member function, first add a "this" argument like: (this : ref structName, ...)
22822345
if (currentStructName.size() > 0) {
22832346
std::string typeStr = currentStructName.top();
22842347
bool isReference = true;
@@ -2351,6 +2414,7 @@ void* ASTNode::generatePrototype(int pass)
23512414

23522415
Type* aType = nullptr;
23532416
argList.push_back(argType(typeStr, getASTNodeTypeFromString(typeNode->token->first), pointerLevel, isReference, mustBeExactType));
2417+
userArgList.push_back(argType(typeStr, getASTNodeTypeFromString(typeNode->token->first), pointerLevel, isReference, mustBeExactType));
23542418

23552419
try {
23562420
bool wasDefined = true;
@@ -2394,7 +2458,7 @@ void* ASTNode::generatePrototype(int pass)
23942458

23952459
// Don't add another prototype if the exact same one is already defined
23962460
//Function* theFunction = TheModule->getFunction(token->first);
2397-
functionID* theFunctionID = getExactFunctionFromID(functionIDs, fnName, argList, token);
2461+
functionID* theFunctionID = getExactFunctionFromID(functionIDs, fnName, userArgList, token);
23982462
if (theFunctionID) {
23992463
if (verbosity >= 5) {
24002464
console::printIndent(2);
@@ -2406,7 +2470,7 @@ void* ASTNode::generatePrototype(int pass)
24062470
}
24072471

24082472

2409-
FunctionType* FT = FunctionType::get(retType, argTypes, variableNumArguments);
2473+
FunctionType* FT = FunctionType::get(actualRetType, argTypes, variableNumArguments);
24102474

24112475
Function* fn = nullptr;
24122476
// If extern declaration, dont mangle name
@@ -2423,7 +2487,7 @@ void* ASTNode::generatePrototype(int pass)
24232487
//namedValues[std::string(arg.getName())] = &arg;
24242488
}
24252489

2426-
functionIDs.push_back(new functionID(fnName, mangledName, rTypeString, argList, fn, variableNumArguments, isStruct));
2490+
functionIDs.push_back(new functionID(fnName, mangledName, rTypeString, argList, userArgList, fn, variableNumArguments, isStruct, isStructReturn));
24272491
if (verbosity >= 5) {
24282492
console::printIndent(depth + 2);
24292493
console::WriteLine("-- Added function \"" + fnName + "\" to functionIDs");

src/tokenizer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ std::map<const std::string, const TokenType> subTokenTypes = {
110110
{">=", Greater_Equal},
111111
{"|", Bar},
112112
{"||", Bar_Bar},
113+
{"|||", Bar_Bar_Bar},
113114
{"&", Ampersand},
114115
{"&&", Ampersand_Ampersand},
115116
{"~", Tilde},

src/tokenizer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ enum TokenType {
5151

5252
Bar,
5353
Bar_Bar,
54+
Bar_Bar_Bar,
5455
Ampersand,
5556
Ampersand_Ampersand,
5657
Tilde,
@@ -145,6 +146,7 @@ const std::string tokenTypeStrings[] = {
145146

146147
"Bar",
147148
"Bar_Bar",
149+
"Bar_Bar_Bar",
148150
"Ampersand",
149151
"Ampersand_Ampersand",
150152
"Tilde",

0 commit comments

Comments
 (0)