235 changes: 235 additions & 0 deletions rocksdb/_rocksdb.pyx
@@ -29,6 +29,7 @@ from . cimport env
from . cimport table_factory
from . cimport memtablerep
from . cimport universal_compaction
from . cimport transaction_db

# Enums are the only exception for direct imports
# Their names are already unique enough
@@ -1499,6 +1500,87 @@ cdef class Options(ColumnFamilyOptions):
            self.py_row_cache = value
            self.opts.row_cache = self.py_row_cache.get_cache()


cdef class TransactionDBOptions(object):
    cdef transaction_db.TransactionDBOptions* opts

    def __cinit__(self):
        self.opts = new transaction_db.TransactionDBOptions()

    def __dealloc__(self):
        if self.opts != NULL:
            del self.opts

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    property max_num_locks:
        def __get__(self):
            return self.opts.max_num_locks
        def __set__(self, value):
            self.opts.max_num_locks = value

    property max_num_deadlocks:
        def __get__(self):
            return self.opts.max_num_deadlocks
        def __set__(self, value):
            self.opts.max_num_deadlocks = value

    property num_stripes:
        def __get__(self):
            return self.opts.num_stripes
        def __set__(self, value):
            self.opts.num_stripes = value

    property transaction_lock_timeout:
        def __get__(self):
            return self.opts.transaction_lock_timeout
        def __set__(self, value):
            self.opts.transaction_lock_timeout = value

    property default_lock_timeout:
        def __get__(self):
            return self.opts.default_lock_timeout
        def __set__(self, value):
            self.opts.default_lock_timeout = value

    # TODO property custom_mutex_factory
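    # Controls when a transaction's writes are applied to the base DB:
    #   'write_committed'  - at Commit() time (RocksDB's default)
    #   'write_prepared'   - at Prepare() time
    #   'write_unprepared' - as the writes are issued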
    property write_policy:
        def __get__(self):
            if self.opts.write_policy == transaction_db.WRITE_COMMITTED:
                return 'write_committed'
            if self.opts.write_policy == transaction_db.WRITE_PREPARED:
                return 'write_prepared'
            if self.opts.write_policy == transaction_db.WRITE_UNPREPARED:
                return 'write_unprepared'
            raise Exception("Unknown write policy")

        def __set__(self, str value):
            if value == 'write_committed':
                self.opts.write_policy = transaction_db.WRITE_COMMITTED
            elif value == 'write_prepared':
                self.opts.write_policy = transaction_db.WRITE_PREPARED
            elif value == 'write_unprepared':
                self.opts.write_policy = transaction_db.WRITE_UNPREPARED
            else:
                raise Exception(f"Unknown write policy: {value!r}")

    property rollback_merge_operands:
        def __get__(self):
            return self.opts.rollback_merge_operands
        def __set__(self, value):
            self.opts.rollback_merge_operands = value

    property skip_concurrency_control:
        def __get__(self):
            return self.opts.skip_concurrency_control
        def __set__(self, value):
            self.opts.skip_concurrency_control = value

    property default_write_batch_flush_threshold:
        def __get__(self):
            return self.opts.default_write_batch_flush_threshold
        def __set__(self, value):
            self.opts.default_write_batch_flush_threshold = value

# Forward declaration
cdef class Snapshot
@@ -2271,6 +2353,159 @@ def list_column_families(db_name, Options opts):

    return column_families

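# no_gc_clear: keep the instance's references intact during cyclic GC so
# __dealloc__ below can still close the DB and release the CF handles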
@cython.no_gc_clear
cdef class TransactionDB(object):
    cdef Options opts
    cdef transaction_db.TransactionDB* db
    cdef list cf_handles
    cdef list cf_options

    def __cinit__(self, db_name, Options opts, TransactionDBOptions tdb_opts, dict column_families=None):
        cdef Status st
        cdef string db_path
        cdef vector[db.ColumnFamilyDescriptor] column_family_descriptors
        cdef vector[db.ColumnFamilyHandle*] column_family_handles
        cdef bytes default_cf_name = db.kDefaultColumnFamilyName
        self.db = NULL
        self.opts = None
        self.cf_handles = []
        self.cf_options = []

        if opts.in_use:
            raise Exception("Options object is already used by another DB")

        db_path = path_to_string(db_name)
        if not column_families or default_cf_name not in column_families:
            # Always add the default column family
            column_family_descriptors.push_back(
                db.ColumnFamilyDescriptor(
                    db.kDefaultColumnFamilyName,
                    options.ColumnFamilyOptions(deref(opts.opts))
                )
            )
            self.cf_options.append(None)  # The default CF shares the DB-wide options
        if column_families:
            for cf_name, cf_options in column_families.items():
                if not isinstance(cf_name, bytes):
                    raise TypeError(
                        f"column family name {cf_name!r} is not of type {bytes}!"
                    )
                if not isinstance(cf_options, ColumnFamilyOptions):
                    raise TypeError(
                        f"column family options {cf_options!r} is not of type "
                        f"{ColumnFamilyOptions}!"
                    )
                if (<ColumnFamilyOptions>cf_options).in_use:
                    raise Exception(
                        f"ColumnFamilyOptions object for {cf_name} is already "
                        "used by another Column Family"
                    )
                (<ColumnFamilyOptions>cf_options).in_use = True
                column_family_descriptors.push_back(
                    db.ColumnFamilyDescriptor(
                        cf_name,
                        deref((<ColumnFamilyOptions>cf_options).copts)
                    )
                )
                self.cf_options.append(cf_options)

        with nogil:
            st = transaction_db.TransactionDB_Open_ColumnFamilies(
                deref(opts.opts),
                deref(tdb_opts.opts),
                db_path,
                column_family_descriptors,
                &column_family_handles,
                &self.db)
        check_status(st)

        for handle in column_family_handles:
            wrapper = _ColumnFamilyHandle.from_handle_ptr(handle)
            self.cf_handles.append(wrapper)

        # Inject the loggers into the python callbacks
        cdef shared_ptr[logger.Logger] info_log = self.db.GetOptions(
            self.db.DefaultColumnFamily()).info_log
        if opts.py_comparator is not None:
            opts.py_comparator.set_info_log(info_log)

        if opts.py_table_factory is not None:
            opts.py_table_factory.set_info_log(info_log)

        if opts.prefix_extractor is not None:
            opts.py_prefix_extractor.set_info_log(info_log)

        cdef ColumnFamilyOptions copts
        for idx, copts in enumerate(self.cf_options):
            if not copts:
                continue

            info_log = self.db.GetOptions(column_family_handles[idx]).info_log

            if copts.py_comparator is not None:
                copts.py_comparator.set_info_log(info_log)

            if copts.py_table_factory is not None:
                copts.py_table_factory.set_info_log(info_log)

            if copts.prefix_extractor is not None:
                copts.py_prefix_extractor.set_info_log(info_log)

        self.opts = opts
        self.opts.in_use = True
    def close(self, safe=True):
        cdef ColumnFamilyOptions copts
        cdef cpp_bool c_safe = safe
        cdef Status st
        if self.db != NULL:
            # We need to stop background compactions; when c_safe is True
            # this waits for running background jobs to finish
            with nogil:
                db.CancelAllBackgroundWork(self.db, c_safe)
            # We have to make sure we delete the handles so rocksdb doesn't
            # assert when we delete the db
            del self.cf_handles[:]
            for copts in self.cf_options:
                if copts:
                    copts.in_use = False
            del self.cf_options[:]
            with nogil:
                st = self.db.Close()
            self.db = NULL
        if self.opts is not None:
            self.opts.in_use = False

    def __dealloc__(self):
        self.close()

    property options:
        def __get__(self):
            return self.opts


@cython.no_gc_clear
@cython.internal
…
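
For reference, a minimal usage sketch of the new bindings (assuming the package re-exports TransactionDB and TransactionDBOptions at the top level of the rocksdb module, which this diff does not show):

    import rocksdb

    opts = rocksdb.Options(create_if_missing=True)
    tdb_opts = rocksdb.TransactionDBOptions(
        transaction_lock_timeout=1000,   # per-transaction lock wait, in ms
        write_policy='write_committed',
    )
    # column_families is optional; the default CF is always opened
    db = rocksdb.TransactionDB('test.db', opts, tdb_opts)
    db.close()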
26 changes: 26 additions & 0 deletions rocksdb/stackable_db.pxd
@@ -0,0 +1,26 @@
from . cimport options
from libc.stdint cimport uint64_t, uint32_t
from .status cimport Status
from libcpp cimport bool as cpp_bool
from libcpp.string cimport string
from libcpp.vector cimport vector
from libcpp.map cimport map
from libcpp.unordered_map cimport unordered_map
from libcpp.memory cimport shared_ptr
from .types cimport SequenceNumber
from .slice_ cimport Slice
from .snapshot cimport Snapshot
from .iterator cimport Iterator
from .env cimport Env
from .metadata cimport ColumnFamilyMetaData
from .metadata cimport LiveFileMetaData
from .metadata cimport ExportImportFilesMetaData
from .table_properties cimport TableProperties
from .db cimport DB

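# StackableDB wraps another DB instance and forwards all calls to it;
# RocksDB utilities such as TransactionDB derive from this interface.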
cdef extern from "rocksdb/utilities/stackable_db.h" namespace "rocksdb":
    cdef cppclass StackableDB(DB):
        StackableDB(DB*) nogil except+
        StackableDB(shared_ptr[DB] db) nogil except+
        DB* GetBaseDB() nogil except+

94 changes: 94 additions & 0 deletions rocksdb/tests/test_stackable_db.py
@@ -0,0 +1,94 @@
import os
import sys
import shutil
import rocksdb

from .test_db import TestHelper

class TestStackableDB(TestHelper):
    def setUp(self):
        TestHelper.setUp(self)
        opts = rocksdb.Options(create_if_missing=True)
        self.db = rocksdb.StackableDB(os.path.join(self.db_loc, "test"), opts)

    def test_options_used_twice(self):
        if sys.version_info[0] == 3:
            assertRaisesRegex = self.assertRaisesRegex
        else:
            assertRaisesRegex = self.assertRaisesRegexp
        expected = "Options object is already used by another DB"
        with assertRaisesRegex(Exception, expected):
            rocksdb.DB(os.path.join(self.db_loc, "test2"), self.db.options)

    def test_unicode_path(self):
        name = os.path.join(self.db_loc, b'M\xc3\xbcnchen'.decode('utf8'))
        rocksdb.DB(name, rocksdb.Options(create_if_missing=True))
        self.addCleanup(shutil.rmtree, name)
        self.assertTrue(os.path.isdir(name))

    def test_get_none(self):
        self.assertIsNone(self.db.get(b'xxx'))

    def test_put_get(self):
        self.db.put(b"a", b"b")
        self.assertEqual(b"b", self.db.get(b"a"))

    def test_multi_get(self):
        self.db.put(b"a", b"1")
        self.db.put(b"b", b"2")
        self.db.put(b"c", b"3")

        ret = self.db.multi_get([b'a', b'b', b'c'])
        ref = {b'a': b'1', b'c': b'3', b'b': b'2'}
        self.assertEqual(ref, ret)

    def test_delete(self):
        self.db.put(b"a", b"b")
        self.assertEqual(b"b", self.db.get(b"a"))
        self.db.delete(b"a")
        self.assertIsNone(self.db.get(b"a"))

    def test_write_batch(self):
        batch = rocksdb.WriteBatch()
        batch.put(b"key", b"v1")
        batch.delete(b"key")
        batch.put(b"key", b"v2")
        batch.put(b"key", b"v3")
        batch.put(b"a", b"b")

        self.db.write(batch)
        ref = {b'a': b'b', b'key': b'v3'}
        ret = self.db.multi_get([b'key', b'a'])
        self.assertEqual(ref, ret)

    def test_write_batch_iter(self):
        batch = rocksdb.WriteBatch()
        self.assertEqual([], list(batch))

        batch.put(b"key1", b"v1")
        batch.put(b"key2", b"v2")
        batch.put(b"key3", b"v3")
        batch.delete(b'a')
        batch.delete(b'key1')
        batch.merge(b'xxx', b'value')

        it = iter(batch)
        del batch
        ref = [
            ('Put', b'key1', b'v1'),
            ('Put', b'key2', b'v2'),
            ('Put', b'key3', b'v3'),
            ('Delete', b'a', b''),
            ('Delete', b'key1', b''),
            ('Merge', b'xxx', b'value')
        ]
        self.assertEqual(ref, list(it))


