44import json
55
66from django import VERSION
7-
7+ from django . core import validators
88from django .db import NotSupportedError , connections , transaction
9- from django .db .models import BooleanField , Value
10- from django .db .models .functions import Cast , NthValue
11- from django .db .models .functions .math import ATan2 , Log , Ln , Mod , Round
12- from django .db .models .expressions import Case , Exists , OrderBy , When , Window , Expression
13- from django .db .models .lookups import Lookup , In
14- from django .db .models import lookups , CheckConstraint
9+ from django .db .models import BooleanField , CheckConstraint , Value
10+ from django .db .models .expressions import Case , Exists , Expression , OrderBy , When , Window
1511from django .db .models .fields import BinaryField , Field
16- from django .db .models .sql .query import Query
12+ from django .db .models .functions import Cast , NthValue
13+ from django .db .models .functions .math import ATan2 , Ln , Log , Mod , Round
14+ from django .db .models .lookups import In , Lookup
1715from django .db .models .query import QuerySet
18- from django .core import validators
16+ from django .db . models . sql . query import Query
1917
2018if VERSION >= (3 , 1 ):
2119 from django .db .models .fields .json import (
@@ -67,9 +65,11 @@ def sqlserver_nth_value(self, compiler, connection, **extra_content):
6765def sqlserver_round (self , compiler , connection , ** extra_context ):
6866 return self .as_sql (compiler , connection , template = '%(function)s(%(expressions)s, 0)' , ** extra_context )
6967
68+
7069def sqlserver_random (self , compiler , connection , ** extra_context ):
7170 return self .as_sql (compiler , connection , function = 'RAND' , ** extra_context )
7271
72+
7373def sqlserver_window (self , compiler , connection , template = None ):
7474 # MSSQL window functions require an OVER clause with ORDER BY
7575 if self .order_by is None :
@@ -125,6 +125,13 @@ def sqlserver_orderby(self, compiler, connection):
125125
126126
127127def split_parameter_list_as_sql (self , compiler , connection ):
128+ if connection .vendor == 'microsoft' :
129+ return mssql_split_parameter_list_as_sql (self , compiler , connection )
130+ else :
131+ return in_split_parameter_list_as_sql (self , compiler , connection )
132+
133+
134+ def mssql_split_parameter_list_as_sql (self , compiler , connection ):
128135 # Insert In clause parameters 1000 at a time into a temp table.
129136 lhs , _ = self .process_lhs (compiler , connection )
130137 _ , rhs_params = self .batch_process_rhs (compiler , connection )
@@ -143,26 +150,29 @@ def split_parameter_list_as_sql(self, compiler, connection):
143150
144151 return in_clause , ()
145152
153+
146154def unquote_json_rhs (rhs_params ):
147155 for value in rhs_params :
148156 value = json .loads (value )
149157 if not isinstance (value , (list , dict )):
150158 rhs_params = [param .replace ('"' , '' ) for param in rhs_params ]
151159 return rhs_params
152160
161+
153162def json_KeyTransformExact_process_rhs (self , compiler , connection ):
154- if isinstance (self .rhs , KeyTransform ):
155- return super (lookups .Exact , self ).process_rhs (compiler , connection )
156- rhs , rhs_params = super (KeyTransformExact , self ).process_rhs (compiler , connection )
163+ rhs , rhs_params = key_transform_exact_process_rhs (self , compiler , connection )
164+ if connection .vendor == 'microsoft' :
165+ rhs_params = unquote_json_rhs (rhs_params )
166+ return rhs , rhs_params
157167
158- return rhs , unquote_json_rhs (rhs_params )
159168
160169def json_KeyTransformIn (self , compiler , connection ):
161170 lhs , _ = super (KeyTransformIn , self ).process_lhs (compiler , connection )
162171 rhs , rhs_params = super (KeyTransformIn , self ).process_rhs (compiler , connection )
163172
164173 return (lhs + ' IN ' + rhs , unquote_json_rhs (rhs_params ))
165174
175+
166176def json_HasKeyLookup (self , compiler , connection ):
167177 # Process JSON path from the left-hand side.
168178 if isinstance (self .lhs , KeyTransform ):
@@ -193,6 +203,7 @@ def json_HasKeyLookup(self, compiler, connection):
193203
194204 return sql % tuple (rhs_params ), []
195205
206+
196207def BinaryField_init (self , * args , ** kwargs ):
197208 # Add max_length option for BinaryField, default to max
198209 kwargs .setdefault ('editable' , False )
@@ -202,6 +213,7 @@ def BinaryField_init(self, *args, **kwargs):
202213 else :
203214 self .max_length = 'max'
204215
216+
205217def _get_check_sql (self , model , schema_editor ):
206218 if VERSION >= (3 , 1 ):
207219 query = Query (model = model , alias_cols = False )
@@ -210,13 +222,16 @@ def _get_check_sql(self, model, schema_editor):
210222 where = query .build_where (self .check )
211223 compiler = query .get_compiler (connection = schema_editor .connection )
212224 sql , params = where .as_sql (compiler , schema_editor .connection )
213- try :
214- for p in params : str (p ).encode ('ascii' )
215- except UnicodeEncodeError :
216- sql = sql .replace ('%s' , 'N%s' )
225+ if schema_editor .connection .vendor == 'microsoft' :
226+ try :
227+ for p in params :
228+ str (p ).encode ('ascii' )
229+ except UnicodeEncodeError :
230+ sql = sql .replace ('%s' , 'N%s' )
217231
218232 return sql % tuple (schema_editor .quote_value (p ) for p in params )
219233
234+
220235def bulk_update_with_default (self , objs , fields , batch_size = None , default = 0 ):
221236 """
222237 Update the given fields in each of the given objects in the database.
@@ -255,10 +270,10 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
255270 attr = getattr (obj , field .attname )
256271 if not isinstance (attr , Expression ):
257272 if attr is None :
258- value_none_counter += 1
273+ value_none_counter += 1
259274 attr = Value (attr , output_field = field )
260275 when_statements .append (When (pk = obj .pk , then = attr ))
261- if ( value_none_counter == len (when_statements ) ):
276+ if connections [ self . db ]. vendor == 'microsoft' and value_none_counter == len (when_statements ):
262277 case_statement = Case (* when_statements , output_field = field , default = Value (default ))
263278 else :
264279 case_statement = Case (* when_statements , output_field = field )
@@ -272,10 +287,15 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
272287 rows_updated += self .filter (pk__in = pks ).update (** update_kwargs )
273288 return rows_updated
274289
290+
275291ATan2 .as_microsoft = sqlserver_atan2
292+ # Need copy of old In.split_parameter_list_as_sql for other backends to call
293+ in_split_parameter_list_as_sql = In .split_parameter_list_as_sql
276294In .split_parameter_list_as_sql = split_parameter_list_as_sql
277295if VERSION >= (3 , 1 ):
278296 KeyTransformIn .as_microsoft = json_KeyTransformIn
297+ # Need copy of old KeyTransformExact.process_rhs to call later
298+ key_transform_exact_process_rhs = KeyTransformExact .process_rhs
279299 KeyTransformExact .process_rhs = json_KeyTransformExact_process_rhs
280300 HasKeyLookup .as_microsoft = json_HasKeyLookup
281301Ln .as_microsoft = sqlserver_ln
0 commit comments