Skip to content

Commit 3b6e6a9

Browse files
committed
Allow to INSERT multiple records
1 parent aa07762 commit 3b6e6a9

File tree

6 files changed

+47
-24
lines changed

6 files changed

+47
-24
lines changed

beanquery/compiler.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -764,21 +764,24 @@ def _insert(self, node: ast.Insert):
764764
impl = getattr(table, 'insert', None)
765765
if impl is None:
766766
raise CompilationError(f'table "{node.table.name}" does not support insertion', node.table)
767-
if len(node.values) != len(node.columns):
768-
raise CompilationError(
769-
f'column names and values mismatch: '
770-
f'expected {len(node.columns)} but {len(node.values)} values were supplied', node)
771-
values = [EvalConstant(None)] * len(table.columns)
772767
columns = {name: i for i, name in enumerate(table.columns.keys())}
773-
for column, value in zip(node.columns, node.values):
774-
index = columns.get(column.name)
775-
if index is None:
776-
raise CompilationError(f'column "{column.name}" not found in table "{node.table.name}"', column)
777-
expr = self._compile(value)
778-
if not expr.dtype == table.columns.get(column.name).dtype:
779-
raise CompilationError(f'expression has wrong type for column "{column.name}"', value)
780-
values[index] = expr
781-
return EvalInsert(table, values)
768+
rows = []
769+
for row in node.values:
770+
if len(row) != len(node.columns):
771+
raise CompilationError(
772+
f'column names and values mismatch: '
773+
f'expected {len(node.columns)} but {len(row)} values were supplied', node)
774+
values = [EvalConstant(None)] * len(table.columns)
775+
for column, value in zip(node.columns, row):
776+
index = columns.get(column.name)
777+
if index is None:
778+
raise CompilationError(f'column "{column.name}" not found in table "{node.table.name}"', column)
779+
expr = self._compile(value)
780+
if not expr.dtype == table.columns.get(column.name).dtype:
781+
raise CompilationError(f'expression has wrong type for column "{column.name}"', value)
782+
values[index] = expr
783+
rows.append(values)
784+
return EvalInsert(table, rows)
782785

783786

784787
def transform_journal(journal):

beanquery/parser/bql.ebnf

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,5 +397,5 @@ create_table::CreateTable
397397
insert::Insert
398398
= 'INSERT' 'INTO' ~ table:table
399399
['(' columns:','.{column} ')']
400-
'VALUES' '(' values:','.{expression} ')'
400+
'VALUES' ','.{ '(' values+:','.{expression}+ ')' }
401401
;

beanquery/parser/parser.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,17 +1256,30 @@ def block1():
12561256
self._token(')')
12571257
self._define(['columns'], [])
12581258
self._token('VALUES')
1259-
self._token('(')
12601259

12611260
def sep2():
12621261
self._token(',')
12631262

12641263
def block3():
1265-
self._expression_()
1264+
self._token('(')
1265+
1266+
def sep4():
1267+
self._token(',')
1268+
1269+
def block5():
1270+
self._expression_()
1271+
self._positive_gather(block5, sep4)
1272+
self.add_last_node_to_name('values')
1273+
self._token(')')
1274+
self._define(
1275+
[],
1276+
['values'],
1277+
)
12661278
self._gather(block3, sep2)
1267-
self.name_last_node('values')
1268-
self._token(')')
1269-
self._define(['columns', 'table', 'values'], [])
1279+
self._define(
1280+
['columns', 'table'],
1281+
['values'],
1282+
)
12701283

12711284

12721285
def main(filename, **kwargs):

beanquery/query_compile.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import operator
1919

2020
from decimal import Decimal
21-
from typing import List
21+
from typing import List, Sequence
2222

2323
from dateutil.relativedelta import relativedelta
2424

@@ -697,9 +697,9 @@ def __call__(self):
697697
@dataclasses.dataclass
698698
class EvalInsert:
699699
table: tables.Table
700-
values: list[EvalNode]
700+
rows: Sequence[Sequence[EvalNode]]
701701

702702
def __call__(self):
703-
values = tuple(value(None) for value in self.values)
704-
self.table.insert(values)
703+
for row in self.rows:
704+
self.table.insert(tuple(value(None) for value in row))
705705
return (), []

beanquery/query_execute_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,6 +1818,12 @@ def test_insert_placeholders(self):
18181818
self.assertEqual(self.conn.tables['abcd'].data[0], values)
18191819
self.assertEqual(curs.fetchall(), [])
18201820

1821+
def test_insert_many(self):
1822+
curs = self.conn.execute('''INSERT INTO abcd (a) VALUES (1), (2), (3), (4)''')
1823+
values = [row[0] for row in self.conn.tables['abcd'].data]
1824+
self.assertEqual(values, [1, 2, 3, 4])
1825+
self.assertEqual(curs.fetchall(), [])
1826+
18211827

18221828
class TestCSVTable(unittest.TestCase):
18231829

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ ignore = [
8484
'PLW2901',
8585
'RUF012',
8686
'RUF023', # unsorted-dunder-slots
87+
'RUF059', # unused-unpacked-variable
8788
'UP007',
8889
'UP032',
8990
]

0 commit comments

Comments
 (0)