@@ -464,33 +464,38 @@ def __call__(self, values: dict[Domain, float | int], method="cog") -> np.floati
464464 """Calculate the infered value based on different methods.
465465 Default is center of gravity (cog).
466466 """
467- assert len (args ) == max (
468- len (c ) for c in self .conditions .keys ()
469- ), "Number of values must correspond to the number of domains defined as conditions!"
470- assert isinstance (args , dict ), "Please make sure to pass in the values as a dictionary."
467+ assert isinstance (values , dict ), "Please make sure to pass a dict[Domain, float|int] as values."
468+ assert len (self .conditions ) > 0 , "No point in having a rule with no conditions, is there?"
471469 match method :
472470 case "cog" :
473- assert (
474- len ({C .domain for C in self .conditions .values ()}) == 1
475- ), "For CoG, all conditions must have the same target domain."
476- actual_values = {f : f (args [f .domain ]) for S in self .conditions .keys () for f in S }
477-
478- weights = []
479- for K , v in self .conditions .items ():
480- x = min ((actual_values [k ] for k in K if k in actual_values ), default = 0 )
471+ # iterate over the conditions and calculate the actual values and weights contributing to cog
472+ target_weights : list [tuple [Set , Number ]] = []
473+ target_domain = list (self .conditions .values ())[0 ].domain
474+ assert target_domain is not None , "Target domain must be defined."
475+ for if_sets , then_set in self .conditions .items ():
476+ actual_values : list [Number ] = []
477+ assert then_set .domain == target_domain , "All target sets must be in the same Domain."
478+ for s in if_sets :
479+ assert s .domain is not None , "Domains must be defined."
480+ actual_values .append (s (values [s .domain ]))
481+ x = min (actual_values , default = 0 )
481482 if x > 0 :
482- weights .append ((v , x ))
483-
484- if not weights :
483+ target_weights .append ((then_set , x ))
484+ if not target_weights :
485485 return None
486- target_domain = list (self .conditions .values ())[0 ].domain
487- index = sum (v .center_of_gravity * x for v , x in weights ) / sum (x for v , x in weights )
486+ sum_weights = 0
487+ sum_weighted_cogs = 0
488+ for then_set , weight in target_weights :
489+ sum_weighted_cogs += then_set .center_of_gravity () * weight
490+ sum_weights += weight
491+ index = sum_weighted_cogs / sum_weights
492+
488493 return (target_domain ._high - target_domain ._low ) / len (
489494 target_domain .range
490495 ) * index + target_domain ._low
491496
492- case "centroid" :
493- raise NotImplementedError ("Centroid method not implemented yet ." )
497+ case "centroid" : # centroid == center of mass == center of gravity for simple solids
498+ raise NotImplementedError ("actually the same as 'cog' if densities are uniform ." )
494499 case "bisector" :
495500 raise NotImplementedError ("Bisector method not implemented yet." )
496501 case "mom" :
@@ -529,7 +534,7 @@ def rule_from_table(table: str, references: dict):
529534 ): eval (df .iloc [x , y ], references ) # type: ignore
530535 for x , y in product (range (len (df .index )), range (len (df .columns )))
531536 }
532- return Rule (D )
537+ return Rule (D ) # type: ignore
533538
534539
535540if __name__ == "__main__" :
0 commit comments