@@ -68,20 +68,24 @@ struct functionID {
6868 std::string mangledName = " " ;
6969 std::string returnType = " " ;
7070 argumentList arguments = argumentList();
71+ argumentList userArguments = argumentList();
7172 bool variableNumArguments = false ;
73+ bool isStructReturn = false ;
7274 uint32_t uses = 0 ;
7375 bool isMemberFunction = false ;
7476 Function* fnValue = nullptr ;
7577 functionID () {}
76- functionID (std::string n, std::string mN , std::string r, argumentList a, Function* f, bool vA = false , bool mF = false )
78+ functionID (std::string n, std::string mN , std::string r, argumentList llvmArgs, argumentList userArgs, Function* f, bool vA = false , bool mF = false , bool sRet = false )
7779 {
7880 name = n;
7981 mangledName = mN ;
8082 returnType = r;
81- arguments = a;
83+ arguments = llvmArgs;
84+ userArguments = userArgs;
8285 fnValue = f;
8386 variableNumArguments = vA;
8487 isMemberFunction = mF ;
88+ isStructReturn = sRet ;
8589 }
8690 void print ()
8791 {
@@ -123,12 +127,12 @@ struct functionID {
123127 uint16_t differences = 0 ;
124128 if (n != name)
125129 return 1000 ;
126- if (arguments .size () != a.size ())
130+ if (userArguments .size () != a.size ())
127131 return 1000 - 1 ;
128- for (int i = 0 ; i < arguments .size (); i++) {
129- ASTNodeType t1 = arguments [i].baseASTType ;
132+ for (int i = 0 ; i < userArguments .size (); i++) {
133+ ASTNodeType t1 = userArguments [i].baseASTType ;
130134 ASTNodeType t2 = a[i].baseASTType ;
131- bool mustBeExactType = arguments [i].mustBeExactType ;
135+ bool mustBeExactType = userArguments [i].mustBeExactType ;
132136 // If t1 is an integer type, make sure t2 is also
133137 // Difference points are given the further the types are
134138
@@ -163,12 +167,12 @@ struct functionID {
163167 uint16_t differences = 0 ;
164168 if (n != name)
165169 return 1000 ;
166- if (arguments .size () != a.size ())
170+ if (userArguments .size () != a.size ())
167171 return 1000 - 1 ;
168- for (int i = 0 ; i < arguments .size (); i++) {
169- ASTNodeType t1 = arguments [i].baseASTType ;
172+ for (int i = 0 ; i < userArguments .size (); i++) {
173+ ASTNodeType t1 = userArguments [i].baseASTType ;
170174 ASTNodeType t2 = a[i]->nodeType ;
171- bool mustBeExactType = arguments [i].mustBeExactType ;
175+ bool mustBeExactType = userArguments [i].mustBeExactType ;
172176 // If t1 is an integer type, make sure t2 is also
173177 // Difference points are given the further the types are
174178
@@ -841,8 +845,6 @@ void* ASTNode::generateConstant(int pass)
841845 return ConstantFP::get (*TheContext, APFloat (stod (token->first )));
842846 else if (nodeType == String_Constant_Node) {
843847 std::string strValue = unescapeString (token->first .substr (1 , token->first .size () - 2 ), token); // remove quotes from token
844- // Add null terminator
845- strValue += ' \0 ' ;
846848
847849 GlobalVariable* globalStr = nullptr ;
848850 if (globalStringLiteralConstants.find (strValue) != globalStringLiteralConstants.end ())
@@ -1873,15 +1875,17 @@ void* ASTNode::generateCallExpression(int pass)
18731875 ASTNode* argsNode = childNodes[0 ];
18741876 bool shouldBeMemberFunction = isCallMemberFunction;
18751877 isCallMemberFunction = false ;
1878+
18761879 std::vector<ASTNode*> args = std::vector<ASTNode*>();
18771880 for (auto & a : argsNode->childNodes )
18781881 if (a->childNodes .size () > 0 ) {
18791882 args.push_back (a);
1880- // aList.push_back(std::make_pair());
18811883 }
18821884
18831885 std::vector<Value*> ArgsV = std::vector<Value*>();
18841886 argumentList argList = argumentList ();
1887+
1888+ // Build argList and ArgsV WITHOUT sret initially
18851889 for (int i = 0 ; i < args.size (); i++) {
18861890 Value* argVal = (Value*)(args[i]->*(args[i]->codegen ))(pass);
18871891 ArgsV.push_back (argVal);
@@ -1890,10 +1894,8 @@ void* ASTNode::generateCallExpression(int pass)
18901894 return nullptr ;
18911895 }
18921896
1893-
1894- // Look up the id in the global module table.
1897+ // Look up the function ID using the caller's argList (without sret)
18951898 functionID* CalleeFID = getFunctionFromID (functionIDs, token->first , argList, token, true , shouldBeMemberFunction);
1896- // Function* CalleeF = TheModule->getFunction(token->first);
18971899 if (!CalleeFID) {
18981900 if (shouldBeMemberFunction)
18991901 printf (" Should be member function\n " );
@@ -1903,28 +1905,51 @@ void* ASTNode::generateCallExpression(int pass)
19031905 }
19041906 exit (1 );
19051907 }
1908+
19061909 Function* CalleeF = CalleeFID->fnValue ;
19071910 CalleeFID->uses ++;
19081911
1909- // If argument mismatch error.
1910- if (CalleeFID->variableNumArguments == false )
1911- if (CalleeF->arg_size () != args.size ()) {
1912- CalleeFID->print ();
1913- printTokenError (token, " Incorrect number of arguments passed to function" , __LINE__);
1912+
1913+ bool isStructReturn = CalleeFID->isStructReturn ;
1914+ AllocaInst* sretAlloc = nullptr ;
1915+ if (isStructReturn) {
1916+ // Get the struct type from definitions
1917+ structType* retStruct = structDefinitions[CalleeFID->returnType ];
1918+ if (retStruct->structVal == nullptr ) {
1919+ printTokenError (token, " Struct return type not fully defined" );
19141920 exit (1 );
19151921 }
1916- // If variable arguments, make sure the amount in call are <= the required amount
1917- else if (CalleeF->arg_size () > args.size ()) {
1922+
1923+ // Allocate space for the returned struct on the caller's stack
1924+ sretAlloc = Builder->CreateAlloca (retStruct->structVal , nullptr , " sret_alloc" );
1925+
1926+ // Insert the sret pointer as the FIRST argument in ArgsV
1927+ ArgsV.insert (ArgsV.begin (), sretAlloc);
1928+
1929+ // For argList matching: Temporarily add sret to argList for validation
1930+ // (This matches how it's stored in functionID)
1931+ argType sretArg (" *" + CalleeFID->returnType , Struct_Type, 1 , false , true );
1932+ argList.insert (argList.begin (), sretArg);
1933+ }
1934+
1935+ // Validate argument count (now including sret if applicable)
1936+ if (CalleeFID->variableNumArguments == false ) {
1937+ if (CalleeF->arg_size () != ArgsV.size ()) { // Use ArgsV.size() which includes sret
1938+ printTokenError (token, " Incorrect number of arguments passed to function (expected " + std::to_string (CalleeF->arg_size ()) + " )" , __LINE__);
19181939 CalleeFID->print ();
1919- printTokenError (token, " Incorrect number of arguments passed to function" , __LINE__);
19201940 exit (1 );
19211941 }
1942+ }
1943+ else if (CalleeF->arg_size () > ArgsV.size ()) {
1944+ printTokenError (token, " Incorrect number of arguments passed to function" , __LINE__);
1945+ CalleeFID->print ();
1946+ exit (1 );
1947+ }
19221948
1923- // Clear arg values list to get values correctly
1924- ArgsV = std::vector<Value*>();
1925- for (int i = 0 ; i < args.size (); i++) {
1926- if (CalleeFID->arguments [i].isReference ) {
1927- if (args[0 ]->childNodes .size () != 1 || args[0 ]->childNodes [0 ]->nodeType != Identifier_Node) {
1949+ ArgsV.clear ();
1950+ for (int i = 0 ; i < args.size (); i++) { // Start from caller's args (sret is already handled)
1951+ if (CalleeFID->arguments [i + (isStructReturn ? 1 : 0 )].isReference ) { // Offset by 1 if sret
1952+ if (args[i]->childNodes .size () != 1 || args[i]->childNodes [0 ]->nodeType != Identifier_Node) {
19281953 printTokenError (token, " Cannot pass value as reference" );
19291954 exit (1 );
19301955 }
@@ -1936,10 +1961,30 @@ void* ASTNode::generateCallExpression(int pass)
19361961 return nullptr ;
19371962 }
19381963
1964+ // If struct return, re-insert sret as first arg (after rebuilding)
1965+ if (isStructReturn) {
1966+ ArgsV.insert (ArgsV.begin (), sretAlloc);
1967+ }
1968+
1969+ // Create the call (returns void for struct returns)
1970+ Value* callResult = Builder->CreateCall (CalleeF, ArgsV, " calltmp" );
1971+
1972+ // For struct returns, return the loaded struct value (or pointer if lvalue)
1973+ if (isStructReturn) {
1974+ if (lvalue) {
1975+ return sretAlloc; // Return pointer for lvalue contexts (e.g., assignment)
1976+ }
1977+ else {
1978+ structType* retStruct = structDefinitions[CalleeFID->returnType ];
1979+ return Builder->CreateLoad (retStruct->structVal , sretAlloc, " sret_load" );
1980+ }
1981+ }
19391982
1940- return Builder->CreateCall (CalleeF, ArgsV, " calltmp" );
1983+ // Non-struct: Return the call result directly
1984+ return callResult;
19411985}
19421986
1987+
19431988// Value*
19441989void * ASTNode::generateIf (int pass)
19451990{
@@ -2258,6 +2303,7 @@ void* ASTNode::generatePrototype(int pass)
22582303 Type* retType = Type::getVoidTy (*TheContext);
22592304 std::string rTypeString = " " ;
22602305 ASTNode* typeNode = childNodes[1 ];
2306+ bool isStructReturn = false ;
22612307 if (typeNode->childNodes .size () > 0 ) {
22622308 recurseAddPointer:
22632309 typeNode = typeNode->childNodes [0 ];
@@ -2273,12 +2319,29 @@ void* ASTNode::generatePrototype(int pass)
22732319
22742320 bool wasDefined = true ;
22752321 retType = getLLVMTypeFromString (rTypeString, 0 , typeNode->token , wasDefined, pass);
2276- // if (wasDefined == false)
2277- // return nullptr;
2322+
2323+ // If the return type is a struct
2324+ if (retType && retType->isStructTy ()) {
2325+ isStructReturn = true ;
2326+ }
2327+ }
2328+
2329+ argumentList userArgList = argList;
2330+
2331+ // Handle struct return by modifying function signature
2332+ Type* actualRetType = retType;
2333+ if (isStructReturn) {
2334+ // For struct returns, add sret parameter as first argument
2335+ argTypes.push_back (retType->getPointerTo ()); // sret parameter (pointer to struct)
2336+ argNames.push_back (" sret" );
2337+ argList.insert (argList.begin (), argType (" *" + rTypeString, getASTNodeTypeFromString (rTypeString), 1 , false , true ));
2338+
2339+ // Change actual return type to void
2340+ actualRetType = Type::getVoidTy (*TheContext);
22782341 }
22792342
22802343 // Get function arguments
2281- // If it is a struct, first add a "this" argument like: (this : ref structName, ...)
2344+ // If it is a struct member function , first add a "this" argument like: (this : ref structName, ...)
22822345 if (currentStructName.size () > 0 ) {
22832346 std::string typeStr = currentStructName.top ();
22842347 bool isReference = true ;
@@ -2351,6 +2414,7 @@ void* ASTNode::generatePrototype(int pass)
23512414
23522415 Type* aType = nullptr ;
23532416 argList.push_back (argType (typeStr, getASTNodeTypeFromString (typeNode->token ->first ), pointerLevel, isReference, mustBeExactType));
2417+ userArgList.push_back (argType (typeStr, getASTNodeTypeFromString (typeNode->token ->first ), pointerLevel, isReference, mustBeExactType));
23542418
23552419 try {
23562420 bool wasDefined = true ;
@@ -2394,7 +2458,7 @@ void* ASTNode::generatePrototype(int pass)
23942458
23952459 // Don't add another prototype if the exact same one is already defined
23962460 // Function* theFunction = TheModule->getFunction(token->first);
2397- functionID* theFunctionID = getExactFunctionFromID (functionIDs, fnName, argList , token);
2461+ functionID* theFunctionID = getExactFunctionFromID (functionIDs, fnName, userArgList , token);
23982462 if (theFunctionID) {
23992463 if (verbosity >= 5 ) {
24002464 console::printIndent (2 );
@@ -2406,7 +2470,7 @@ void* ASTNode::generatePrototype(int pass)
24062470 }
24072471
24082472
2409- FunctionType* FT = FunctionType::get (retType , argTypes, variableNumArguments);
2473+ FunctionType* FT = FunctionType::get (actualRetType , argTypes, variableNumArguments);
24102474
24112475 Function* fn = nullptr ;
24122476 // If extern declaration, dont mangle name
@@ -2423,7 +2487,7 @@ void* ASTNode::generatePrototype(int pass)
24232487 // namedValues[std::string(arg.getName())] = &arg;
24242488 }
24252489
2426- functionIDs.push_back (new functionID (fnName, mangledName, rTypeString, argList, fn, variableNumArguments, isStruct));
2490+ functionIDs.push_back (new functionID (fnName, mangledName, rTypeString, argList, userArgList, fn, variableNumArguments, isStruct, isStructReturn ));
24272491 if (verbosity >= 5 ) {
24282492 console::printIndent (depth + 2 );
24292493 console::WriteLine (" -- Added function \" " + fnName + " \" to functionIDs" );
0 commit comments