Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 112 additions & 52 deletions callback_query_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,83 +349,143 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface
sourceKeys = append(sourceKeys, key.DBName)
}

// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)
// get source foreign key field names
var foreignFieldNames []string
for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}

// generate query with join table
newScope := scope.New(reflect.New(fieldType).Interface())
preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value)
// get all source foreign key values
indirectScopeValue := scope.IndirectValue()
var sourceForeignKeys []interface{}
if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
sourceForeignKeys = append(sourceForeignKeys, getValueFromFields(object, foreignFieldNames))
}
} else if indirectScopeValue.IsValid() {
sourceForeignKeys = append(sourceForeignKeys, getValueFromFields(indirectScopeValue, foreignFieldNames))
}

if len(preloadDB.search.selects) == 0 {
preloadDB = preloadDB.Select("*")
if len(sourceForeignKeys) == 0 {
return
}

preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value)
// deduplicate source foreign keys to avoid redundant queries
uniqueSourceKeysMap := map[string]bool{}
for _, key := range sourceForeignKeys {
uniqueSourceKeysMap[toString(key)] = true
}

// preload inline conditions
if len(preloadConditions) > 0 {
preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
uniqueSourceKeyStrings := make([]string, 0, len(uniqueSourceKeysMap))
for key := range uniqueSourceKeysMap {
uniqueSourceKeyStrings = append(uniqueSourceKeyStrings, key)
}

rows, err := preloadDB.Rows()
// preload conditions
preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions)

if scope.Err(err) != nil {
return
}
defer rows.Close()
// need to query relations in chunk of 2000
// to avoid exceeding the mssql parameter limit of 2100
chunkSize := 2000
for chunkIdx := 0; chunkIdx < len(uniqueSourceKeyStrings); chunkIdx += chunkSize {
var sourceKeyChunk []string
if chunkSize > len(uniqueSourceKeyStrings)-chunkIdx {
sourceKeyChunk = uniqueSourceKeyStrings[chunkIdx:]
} else {
sourceKeyChunk = uniqueSourceKeyStrings[chunkIdx : chunkIdx+chunkSize]
}

columns, _ := rows.Columns()
for rows.Next() {
var (
elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
)
// create a temporary slice containing only records in this chunk
chunkSourceKeysMap := map[string]bool{}
for _, key := range sourceKeyChunk {
chunkSourceKeysMap[key] = true
}

var chunkSourceValue reflect.Value
if indirectScopeValue.Kind() == reflect.Slice {
chunkSourceValue = reflect.MakeSlice(indirectScopeValue.Type(), 0, len(sourceKeyChunk))
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
key := toString(getValueFromFields(object, foreignFieldNames))
if chunkSourceKeysMap[key] {
chunkSourceValue = reflect.Append(chunkSourceValue, indirectScopeValue.Index(j))
}
}
} else {
chunkSourceValue = reflect.ValueOf(scope.Value)
}

// generate query with join table for this chunk
newScope := scope.New(reflect.New(fieldType).Interface())
chunkPreloadDB := preloadDB.Table(newScope.TableName()).Model(newScope.Value)

// register foreign keys in join tables
var joinTableFields []*Field
for _, sourceKey := range sourceKeys {
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
if len(chunkPreloadDB.search.selects) == 0 {
chunkPreloadDB = chunkPreloadDB.Select("*")
}

scope.scan(rows, columns, append(fields, joinTableFields...))
chunkPreloadDB = joinTableHandler.JoinWith(joinTableHandler, chunkPreloadDB, chunkSourceValue.Interface())

scope.New(elem.Addr().Interface()).
InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries)
// preload inline conditions
if len(preloadConditions) > 0 {
chunkPreloadDB = chunkPreloadDB.Where(preloadConditions[0], preloadConditions[1:]...)
}

rows, err := chunkPreloadDB.Rows()

if scope.Err(err) != nil {
return
}

var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {
if !joinTableField.Field.IsNil() {
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
columns, _ := rows.Columns()
for rows.Next() {
var (
elem = reflect.New(fieldType).Elem()
fields = scope.New(elem.Addr().Interface()).Fields()
)

// register foreign keys in join tables
var joinTableFields []*Field
for _, sourceKey := range sourceKeys {
joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()})
}

scope.scan(rows, columns, append(fields, joinTableFields...))

scope.New(elem.Addr().Interface()).
InstanceSet("gorm:skip_query_callback", true).
callCallbacks(scope.db.parent.callbacks.queries)

var foreignKeys = make([]interface{}, len(sourceKeys))
// generate hashed forkey keys in join table
for idx, joinTableField := range joinTableFields {
if !joinTableField.Field.IsNil() {
foreignKeys[idx] = joinTableField.Field.Elem().Interface()
}
}
hashedSourceKeys := toString(foreignKeys)

if isPtr {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
} else {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
}
}
hashedSourceKeys := toString(foreignKeys)

if isPtr {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr())
} else {
linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem)
if err := rows.Err(); err != nil {
scope.Err(err)
}
}

if err := rows.Err(); err != nil {
scope.Err(err)
rows.Close()
}

// assign find results
var (
indirectScopeValue = scope.IndirectValue()
fieldsSourceMap = map[string][]reflect.Value{}
foreignFieldNames = []string{}
fieldsSourceMap = map[string][]reflect.Value{}
)

for _, dbName := range relation.ForeignFieldNames {
if field, ok := scope.FieldByName(dbName); ok {
foreignFieldNames = append(foreignFieldNames, field.Name)
}
}

if indirectScopeValue.Kind() == reflect.Slice {
for j := 0; j < indirectScopeValue.Len(); j++ {
object := indirect(indirectScopeValue.Index(j))
Expand Down