diff --git a/.cirrus.yml b/.cirrus.yml index 976d57365e11..8bc2e926d20e 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -42,7 +42,7 @@ task: install_script: - apt-get update # qml test reqs: - - apt-get -y install libgl1 libegl1 libxkbcommon0 libdbus-1-3 + - apt-get -y install libgl1 libegl1 libxkbcommon0 libdbus-1-3 libleveldb-dev - pip install -r $ELECTRUM_REQUIREMENTS_CI # electrum itself: - export ELECTRUM_ECC_DONT_COMPILE=1 diff --git a/contrib/requirements/requirements-ci.txt b/contrib/requirements/requirements-ci.txt index d9644b580f97..a26f412e8739 100644 --- a/contrib/requirements/requirements-ci.txt +++ b/contrib/requirements/requirements-ci.txt @@ -1,3 +1,4 @@ pytest coverage coveralls +plyvel diff --git a/electrum/__init__.py b/electrum/__init__.py index e94ac78ac9d9..6f27c7fb7562 100644 --- a/electrum/__init__.py +++ b/electrum/__init__.py @@ -17,7 +17,7 @@ class GuiImportError(ImportError): from .version import ELECTRUM_VERSION from .util import format_satoshis from .wallet import Wallet -from .storage import WalletStorage +from .stored_dict import WalletStorage from .coinchooser import COIN_CHOOSERS from .network import Network, pick_random_server from .interface import Interface diff --git a/electrum/commands.py b/electrum/commands.py index ea57c7808ff9..26539df14a9a 100644 --- a/electrum/commands.py +++ b/electrum/commands.py @@ -274,7 +274,7 @@ async def list_wallets(self): """List wallets open in daemon""" return [ { - 'path': w.db.storage.path, + 'path': w.storage.get_path() if w.storage else None, 'synchronized': w.is_up_to_date(), 'unlocked': not w.has_password() or (w.get_unlocked_password() is not None), } @@ -298,13 +298,14 @@ async def close_wallet(self, wallet_path=None): return await self.daemon._stop_wallet(wallet_path) @command('') - async def create(self, passphrase=None, password=None, encrypt_file=True, seed_type=None, wallet_path=None): + async def create(self, passphrase=None, password=None, encrypt_file=True, seed_type=None, wallet_path=None, use_levelDB=False): """Create a new wallet. If you want to be prompted for an argument, type '?' or ':' (concealed) arg:str:passphrase:Seed extension arg:str:seed_type:The type of wallet to create, e.g. 'standard' or 'segwit' arg:bool:encrypt_file:Whether the file on disk should be encrypted with the provided password + arg:bool:use_levelDB:Create levelDB storage. Note that LevelDB storage does not support file encryption. The password will only encrypt the keystore. """ d = create_new_wallet( path=wallet_path, @@ -312,15 +313,17 @@ async def create(self, passphrase=None, password=None, encrypt_file=True, seed_t password=password, encrypt_file=encrypt_file, seed_type=seed_type, + use_levelDB=use_levelDB, config=self.config) + wallet = d['wallet'] return { 'seed': d['seed'], - 'path': d['wallet'].storage.path, + 'path': wallet.storage.get_path(), 'msg': d['msg'], } @command('') - async def restore(self, text, passphrase=None, password=None, encrypt_file=True, wallet_path=None): + async def restore(self, text, passphrase=None, password=None, encrypt_file=True, wallet_path=None, use_levelDB=False): """Restore a wallet from text. Text can be a seed phrase, a master public key, a master private key, a list of bitcoin addresses or bitcoin private keys. @@ -329,6 +332,7 @@ async def restore(self, text, passphrase=None, password=None, encrypt_file=True, arg:str:text:seed phrase arg:str:passphrase:Seed extension arg:bool:encrypt_file:Whether the file on disk should be encrypted with the provided password + arg:bool:use_levelDB:Create levelDB storage. Note that LevelDB storage does not support file encryption. The password will only encrypt the keystore. """ # TODO create a separate command that blocks until wallet is synced d = restore_wallet_from_text( @@ -337,9 +341,11 @@ async def restore(self, text, passphrase=None, password=None, encrypt_file=True, passphrase=passphrase, password=password, encrypt_file=encrypt_file, + use_levelDB=use_levelDB, config=self.config) + wallet = d['wallet'] return { - 'path': d['wallet'].storage.path, + 'path': wallet.storage.get_path(), 'msg': d['msg'], } @@ -356,7 +362,7 @@ async def password(self, password=None, new_password=None, encrypt_file=None, wa if encrypt_file is None: if not password and new_password: # currently no password, setting one now: we encrypt by default - encrypt_file = True + encrypt_file = wallet.storage.supports_file_encryption() else: encrypt_file = wallet.storage.is_encrypted() wallet.update_password(password, new_password, encrypt_storage=encrypt_file) diff --git a/electrum/daemon.py b/electrum/daemon.py index db080aa91769..b8fc3c370842 100644 --- a/electrum/daemon.py +++ b/electrum/daemon.py @@ -47,7 +47,7 @@ log_exceptions, randrange, OldTaskGroup, UserFacingException, JsonRPCError, os_chmod ) from .wallet import Wallet, Abstract_Wallet -from .storage import WalletStorage +from .stored_dict import WalletStorage from .wallet_db import WalletDB, WalletUnfinished from .commands import known_commands, Commands from .simple_config import SimpleConfig @@ -551,8 +551,7 @@ def _load_wallet( if not password: raise InvalidPassword('No password given') storage.decrypt(password) - # read data, pass it to db - db = WalletDB(storage.read(), storage=storage, upgrade=upgrade) + db = WalletDB(storage.get_stored_dict(), upgrade=upgrade) if db.get_action(): raise WalletUnfinished(db) wallet = Wallet(db, config=config) @@ -612,8 +611,8 @@ async def _stop_wallet(self, path: str) -> bool: return False await wallet.stop() if self.config.get('wallet_path') is None: - wallet_paths = [w.db.storage.path for w in self._wallets.values() - if w.db.storage and w.db.storage.path] + wallet_paths = [w.storage.path for w in self._wallets.values() + if w.storage and w.storage.path] if self.config.CURRENT_WALLET == path and wallet_paths: self.config.CURRENT_WALLET = wallet_paths[0] return True diff --git a/electrum/gui/qml/qedaemon.py b/electrum/gui/qml/qedaemon.py index b8509bdd9769..5aa620434dd9 100644 --- a/electrum/gui/qml/qedaemon.py +++ b/electrum/gui/qml/qedaemon.py @@ -13,7 +13,8 @@ from electrum.lnchannel import ChannelState from electrum.bitcoin import is_address from electrum.bitcoin import verify_usermessage_with_address -from electrum.storage import StorageReadWriteError, WalletStorage +from electrum.storage import StorageReadWriteError +from electrum.stored_dict import WalletStorage from .auth import AuthMixin, auth_protect from .qefx import QEFX diff --git a/electrum/gui/qml/qewallet.py b/electrum/gui/qml/qewallet.py index c51d9bb5c706..090d91bc97a1 100644 --- a/electrum/gui/qml/qewallet.py +++ b/electrum/gui/qml/qewallet.py @@ -751,7 +751,8 @@ def setPassword(self, password): try: self._logger.info('setting new password') - self.wallet.update_password(current_password, password, encrypt_storage=True) + encrypt_storage = self.wallet.storage.supports_file_encryption() + self.wallet.update_password(current_password, password, encrypt_storage=encrypt_storage) # restore the invariant that all loaded wallets in qml must be unlocked: self.wallet.unlock(password) return True diff --git a/electrum/gui/qt/history_list.py b/electrum/gui/qt/history_list.py index c1d44fff8a05..dcafce477d72 100644 --- a/electrum/gui/qt/history_list.py +++ b/electrum/gui/qt/history_list.py @@ -142,6 +142,8 @@ def get_data_for_role(self, index: QModelIndex, role: Qt.ItemDataRole) -> QVaria assert index.isValid() col = index.column() window = self.model.window + if not window.isVisible(): + return tx_item = self.get_data() is_lightning = tx_item.get('lightning', False) if not is_lightning and 'txid' not in tx_item: diff --git a/electrum/gui/qt/main_window.py b/electrum/gui/qt/main_window.py index 6e1c35d3e797..23707a80d1c1 100644 --- a/electrum/gui/qt/main_window.py +++ b/electrum/gui/qt/main_window.py @@ -1950,8 +1950,12 @@ def on_password(hw_dev_pw): self.update_lock_menu() def _update_wallet_password(self, *, old_password, new_password, xpub_encrypt=False): + encrypt_storage = self.wallet.storage.supports_file_encryption() try: - self.wallet.update_password(old_password, new_password, encrypt_storage=True, xpub_encrypt=xpub_encrypt) + self.wallet.update_password( + old_password, new_password, + encrypt_storage=encrypt_storage, + xpub_encrypt=xpub_encrypt) except InvalidPassword as e: self.show_error(str(e)) return diff --git a/electrum/gui/qt/wizard/wallet.py b/electrum/gui/qt/wizard/wallet.py index 8397b7ba31d4..ccf60904cc5e 100644 --- a/electrum/gui/qt/wizard/wallet.py +++ b/electrum/gui/qt/wizard/wallet.py @@ -248,8 +248,8 @@ def __init__(self, parent, wizard): path = wizard._path - if os.path.isdir(path): - raise Exception("wallet path cannot point to a directory") + #if os.path.isdir(path): + # raise Exception("wallet path cannot point to a directory") self.wallet_exists = False self.wallet_is_open = False diff --git a/electrum/invoices.py b/electrum/invoices.py index 7d6a8bbbbef5..7525b71200cf 100644 --- a/electrum/invoices.py +++ b/electrum/invoices.py @@ -236,7 +236,7 @@ def get_id(self) -> str: else: # on-chain return get_id_from_onchain_outputs(outputs=self.get_outputs(), timestamp=self.time) - def as_dict(self, status): + def export(self, status): d = { 'is_lightning': self.is_lightning(), 'amount_BTC': format_satoshis(self.get_amount_sat()), diff --git a/electrum/json_db.py b/electrum/json_db.py index e7d78f97a50a..5296290e938a 100644 --- a/electrum/json_db.py +++ b/electrum/json_db.py @@ -30,15 +30,12 @@ import jsonpatch import jsonpointer -from . import util from .util import WalletFileException, profiler, sticky_property from .logging import Logger -from .stored_dict import StoredDict, _FLEX_KEY, registered_names, registered_keys, _convert_dict_key, _convert_dict_value +from .stored_dict import FLEX_KEY, BaseDB, json_default +from .storage import FileStorage -if TYPE_CHECKING: - from .storage import WalletStorage - # We monkeypatch exceptions in the jsonpatch package to ensure they do not contain secrets from the DB. # We often log exceptions and offer to send them to the crash reporter, so they must not contain secrets. @@ -54,9 +51,11 @@ setattr(jsonpatch.JsonPatchException, '__suppress_context__', sticky_property(True)) -def key_path(path: Sequence[_FLEX_KEY], key: _FLEX_KEY) -> str: - def to_str(x: _FLEX_KEY) -> str: - assert isinstance(x, _FLEX_KEY), repr(x) + + +def key_path(path: Sequence[FLEX_KEY], key: FLEX_KEY) -> str: + def to_str(x: FLEX_KEY) -> str: + assert isinstance(x, FLEX_KEY), repr(x) assert x is not None if isinstance(x, int): return str(int(x)) @@ -69,6 +68,7 @@ def to_str(x: _FLEX_KEY) -> str: return '/'.join(items) + def modifier(func): def wrapper(self, *args, **kwargs): with self.lock: @@ -85,35 +85,143 @@ def wrapper(self, *args, **kwargs): -class JsonDB(Logger): +class JsonDB(BaseDB): def __init__( - self, - s: str, - *, - storage: Optional['WalletStorage'] = None, - encoder=None, - upgrader=None, + self, + path: Optional[str], + *, + allow_partial_writes = True, + init_db = True, ): - Logger.__init__(self) + BaseDB.__init__(self, path) self.lock = threading.RLock() - self.storage = storage - self.encoder = encoder self.pending_changes = [] # type: List[str] self._modified = False - # load data - data = self.load_data(s) - if upgrader: - data, was_upgraded = upgrader(data) - self._modified |= was_upgraded - # convert json to python objects - data = self._convert_dict([], data) - # convert dict to StoredDict - self.data = StoredDict(data, self) - self.data.set_parent(key='', parent=None) + if self.path: + self.storage = FileStorage(path, allow_partial_writes=allow_partial_writes) + if init_db and not self.is_encrypted(): + # open DB if file is not encrypted + # otherwise, this will be called in self.decrypt + self.init_db() + else: + self.storage = None + self.json_data = {} + + def set_data(self, json_str): + self.json_data = self.load_data(json_str) + + def init_db(self): + if self.storage.is_encrypted(): + assert self.storage.is_past_initial_decryption() + json_str = self.storage.read() + self.json_data = self.load_data(json_str) # write file in case there was a db upgrade - if self.storage and self.storage.file_exists(): - self.write_and_force_consolidation() + self.write_and_force_consolidation() + + def decrypt(self, password: str): + self.storage.decrypt(password) + json_str = self.storage.read() + self.json_data = self.load_data(json_str) + + def check_password(self, password): + self.storage.check_password(password) + + def basename(self) -> str: + return self.storage.basename() if self.storage else 'no name' + + def supports_file_encryption(self): + return bool(self.storage) + + def get_encryption_version(self): + return self.storage.get_encryption_version() + + def is_encrypted(self): + return self.storage and self.storage.is_encrypted() + + def is_encrypted_with_user_pw(self) -> bool: + return self.storage and self.storage.is_encrypted_with_user_pw() + + def is_encrypted_with_hw_device(self) -> bool: + return self.storage and self.storage.is_encrypted_with_hw_device() + + def set_password(self, password: str, enc_version=None): + self.storage.set_password(password, enc_version=enc_version) + + def file_exists(self): + return self.storage and self.storage.file_exists() + + def _subdict(self, path): + d = self.json_data + for k in path[1:]: + d = d[k] + return d + + def iter_keys(self, path): + d = self._subdict(path) + return d.__iter__() + + def dict_len(self, path): + d = self._subdict(path) + return len(d) + + def contains(self, path, key): + d = self._subdict(path) + return key in d + + def replace(self, path, key, value): + # called by setattr + self.put(path, key, value) + + @modifier + def put(self, path, key, value): + d = self._subdict(path) + value = json.loads(json.dumps(value, default=json_default)) # default() is applied recursively + is_new = key not in d + d[key] = value + self.db_add(path, key, value) if is_new else self.db_replace(path, key, value) + + @modifier + def clear(self, path): + d = self._subdict(path) + d.clear() + + def get(self, path, key): + d = self._subdict(path) + return d[key] + + @modifier + def remove(self, path, key): + d = self._subdict(path) + d.pop(key) + self.db_remove(path, key) + + @modifier + def list_append(self, path, item): + _list = self._subdict(path) + _list.append(item) + n = len(_list) + self.db_add(path, str(n), item) + + def list_index(self, path, item): + _list = self._subdict(path) + return _list.index(item) + + def list_len(self, path): + _list = self._subdict(path) + return len(_list) + + @modifier + def list_clear(self, path): + _list = self._subdict(path) + _list.clear() # fixme + + @modifier + def list_remove(self, path, item): + _list = self._subdict(path) + n = _list.index(item) + _list.remove(item) + self.db_remove(path, str(n)) # fixme: keys def load_data(self, s: str) -> dict: if s == '': @@ -181,101 +289,51 @@ def modified(self): @locked def add_patch(self, patch): - self.pending_changes.append(json.dumps(patch, cls=self.encoder)) + self.pending_changes.append(json.dumps(patch, default=json_default)) self.set_modified(True) - def add(self, path, key: _FLEX_KEY, value) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) + def db_add(self, path, key: FLEX_KEY, value) -> None: + assert isinstance(key, FLEX_KEY), repr(key) self.add_patch({'op': 'add', 'path': key_path(path, key), 'value': value}) - def replace(self, path, key: _FLEX_KEY, value) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) + def db_replace(self, path, key: FLEX_KEY, value) -> None: + assert isinstance(key, FLEX_KEY), repr(key) self.add_patch({'op': 'replace', 'path': key_path(path, key), 'value': value}) - def remove(self, path, key: _FLEX_KEY) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) + def db_remove(self, path, key: FLEX_KEY) -> None: + assert isinstance(key, FLEX_KEY), repr(key) self.add_patch({'op': 'remove', 'path': key_path(path, key)}) - @locked - def get(self, key, default=None): - v = self.data.get(key) - if v is None: - v = default - return v - - @modifier - def put(self, key, value): - try: - json.dumps(key, cls=self.encoder) - json.dumps(value, cls=self.encoder) - except Exception: - self.logger.info(f"json error: cannot save {repr(key)} ({repr(value)})") - return False - if value is not None: - if self.data.get(key) != value: - self.data[key] = copy.deepcopy(value) - return True - elif key in self.data: - self.data.pop(key) - return True - return False - - @locked - def get_dict(self, name) -> dict: - # Warning: interacts un-intuitively with 'put': certain parts - # of 'data' will have pointers saved as separate variables. - if name not in self.data: - self.data[name] = {} - return self.data[name] - - @locked - def get_stored_item(self, key, default) -> dict: - if key not in self.data: - self.data[key] = default - return self.data[key] - @locked def dump(self, *, human_readable: bool = True) -> str: """Serializes the DB as a string. 'human_readable': makes the json indented and sorted, but this is ~2x slower """ return json.dumps( - self.data, + self.json_data, indent=4 if human_readable else None, sort_keys=bool(human_readable), - cls=self.encoder, + default=json_default, ) - def _should_convert_to_stored_dict(self, key) -> bool: - return True - - def _convert_dict_key(self, path: List[str], key: str) -> _FLEX_KEY: - return _convert_dict_key(path, key) - - def _convert_dict_value(self, path: List[str], v) -> Any: - v = _convert_dict_value(path, v) - if isinstance(v, dict): - v = self._convert_dict(path, v) - return v - - def _convert_dict(self, path: List[str], data: dict): - # recursively convert json dict to StoredDict - assert all(isinstance(x, str) for x in path), repr(path) - d = {} - for k, v in list(data.items()): - child_path = path + [k] - k = self._convert_dict_key(path, k) - v = self._convert_dict_value(child_path, v) - d[k] = v - return d - @locked def write(self): - if self.storage.should_do_full_write_next(): + if not self.storage: + return + if not self._write_batch and self.storage.should_do_full_write_next(): self.write_and_force_consolidation() else: self._append_pending_changes() + def set_write_batch(self): + self._write_batch = True + + def clear_write_batch(self): + self._write_batch = False + + def close(self): + pass + @locked def _append_pending_changes(self): if threading.current_thread().daemon: diff --git a/electrum/keystore.py b/electrum/keystore.py index 0d7fd8e34f10..1e440250a0e6 100644 --- a/electrum/keystore.py +++ b/electrum/keystore.py @@ -1126,14 +1126,14 @@ def hardware_keystore(d) -> Hardware_KeyStore: f'hw_keystores: {list(hw_keystores)}') def load_keystore(db: 'WalletDB', name: str) -> KeyStore: - # deepcopy object to avoid keeping a pointer to db.data - # note: this is needed as type(wallet.db.get("keystore")) != StoredDict - d = copy.deepcopy(db.get(name, {})) + x = db.get(name) + if x is None: + raise WalletFileException('Cannot find keystore for name {}'.format(name)) + # convert StoredDict to dict + d = x.as_dict() t = d.get('type') if not t: - raise WalletFileException( - 'Wallet format requires update.\n' - 'Cannot find keystore for name {}'.format(name)) + raise WalletFileException('Cannot find keystore for name {}'.format(name)) keystore_constructors = {ks.type: ks for ks in [Old_KeyStore, Imported_KeyStore, BIP32_KeyStore]} keystore_constructors['hardware'] = hardware_keystore try: diff --git a/electrum/level_db.py b/electrum/level_db.py new file mode 100644 index 000000000000..24d2a4ae7d32 --- /dev/null +++ b/electrum/level_db.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import json +import threading +import os +from typing import Any, Optional, Tuple, Union, Iterator, Iterable +import plyvel + +from .stored_dict import BaseDB, FLEX_KEY, key_to_str, json_default + +# Todo: +# - simplify path: first element is unused + + +def locked(func): + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + return wrapper + + +class JsonCodec: + """Default value codec: JSON (utf-8).""" + @staticmethod + def dumps(value: Any) -> bytes: + return json.dumps(value, separators=(",", ":"), ensure_ascii=False, default=json_default).encode("utf-8") + + @staticmethod + def loads(data: bytes) -> Any: + return json.loads(data.decode("utf-8")) + + +def _to_bytes_key(k: str) -> bytes: + if isinstance(k, int): + k = str(k) + return k.encode("utf-8") + +def _to_str_key(k: bytes) -> str: + return k.decode("utf-8") + + +# bump this if key/value encoding changes +STORAGE_VERSION = str(0).encode('utf-8') +VERSION_FILENAME = 'ELECTRUM_LEVELDB_VERSION' + +class LevelDB(BaseDB): + + def __init__( + self, + path: str, + init_db = True, + ): + assert path # in-memory only is only allowed with JsonDB + BaseDB.__init__(self, path) + self.lock = threading.RLock() + self.delimiter = "/" + self.codec = JsonCodec + if init_db: + self.init_db() + + def basename(self) -> str: + return os.path.basename(self.path) + + def is_encrypted(self): + return False + + def file_exists(self): + return os.path.exists(self.path) + + def supports_file_encryption(self): + return False + + def is_encrypted_with_hw_device(self): + return False + + def is_encrypted_with_user_pw(self): + return False + + def init_db(self): + # if path exists, check version file + version_file = os.path.join(self.path, VERSION_FILENAME) + if os.path.exists(self.path): + if not os.path.exists(version_file): + raise Exception('Not an Electrum DB') + with open(version_file, "rb") as f: + v = f.read() + # no upgrades support for the moment + if v != STORAGE_VERSION: + raise Exception('Unsupported DB version') + # create DB + # according to the docs, setting write_buffer_size + # to zero forces levelDB to write directly to disk + self.db = plyvel.DB( + self.path, + create_if_missing=True, + write_buffer_size=0, + ) + # create version file + if not os.path.exists(version_file): + with open(version_file, "wb") as f: + f.write(STORAGE_VERSION) + # set permissions + self._set_permissions() + + def _set_permissions(self): + os.chmod(self.path, 0o700) + for path, dirs, files in os.walk(self.path): + for x in files: os.chmod(os.path.join(path, x), 0o600) + for x in dirs: os.chmod(os.path.join(path, x), 0o700) + + def _debug(self): + for k, v in self.db.iterator(): + self.logger.info(f"{k} -> {v}") + + def close(self) -> None: + if self.db is not None: + self.logger.info('closing database') + self.db.close() + self.db = None + + def set_modified(self, b): + # fixme: callers should not have to do that + pass + + def __enter__(self) -> "LevelDB": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + def write(self): + pass + + def write_and_force_consolidation(self): + # called after password update. + # remove remnants encrypted with old passwordd + self.db.compact_range() + + def _prefix_bytes(self, path) -> bytes: + assert path[0] == '' + d = self.delimiter.encode("utf-8") + p = d.join([_to_bytes_key(x) for x in path]) + if not p: + return b"" + # Ensure exactly one trailing delimiter for internal prefix usage + if p.endswith(d): + p = p[:-len(d)] + return p + + def _full_key(self, path, key: FLEX_KEY) -> bytes: + return self._prefix_bytes(path + [key]) + + def _child_prefix(self, path, key: FLEX_KEY) -> bytes: + d = self.delimiter.encode("utf-8") + return self._full_key(path, key) + d + + def _has_children(self, path, key: FLEX_KEY) -> bool: + db = self.db + if db is None: + raise RuntimeError("DB is closed") + pfx = self._child_prefix(path, key) + it = db.iterator(prefix=pfx, include_value=False) + try: + next(it) + return True + except StopIteration: + return False + + def iter_keys(self, path) -> Iterator[str]: + """ + Iterate unique top-level keys at this view's prefix. + """ + db = self.db + if db is None: + raise RuntimeError("DB is closed") + d = self.delimiter.encode("utf-8") + pb = self._prefix_bytes(path) + d + seen = set() + for k, _v in db.iterator(prefix=pb): + rel = k[len(pb):] if pb else k + first = rel.split(d, 1)[0] + if first not in seen: + seen.add(first) + yield _to_str_key(first) + + @locked + def remove(self, path, key): + self._delete_subtree(path, key, wb=None) + + def _delete_subtree(self, path, key: FLEX_KEY, wb: Optional[plyvel.WriteBatch] = None) -> None: + db = self.db + if db is None: + raise RuntimeError("DB is closed") + pfx = self._child_prefix(path, key) + it = db.iterator(prefix=pfx, include_value=False) + deleter = wb.delete if wb is not None else db.delete + for k in it: + deleter(k) + # delete scalar at node itself, if present + k = self._full_key(path, key) + deleter(k) + if wb is None: + r = db.get(k) + assert r is None, r + + @locked + def clear(self, path) -> None: + db = self.db + if db is None: + raise RuntimeError("DB is closed") + pb = self._prefix_bytes(path) + with db.write_batch() as wb: + for k, _v in db.iterator(prefix=pb): + wb.delete(k) + + @locked + def get(self, path, key: FLEX_KEY) -> Any: + db = self.db + if db is None: + raise RuntimeError("DB is closed") + raw = db.get(self._full_key(path, key)) + if raw is None: + raise KeyError((path, key, self._full_key(path, key))) + return self.codec.loads(raw) # json to python + + def _flatten_into_batch(self, base_key: bytes, value: Any, wb: plyvel.WriteBatch) -> None: + d = self.delimiter.encode("utf-8") + if isinstance(value, dict): + wb.put(base_key, self.codec.dumps({})) + for k, v in value.items(): + k = key_to_str(k) + child_key = base_key + d + _to_bytes_key(k) + self._flatten_into_batch(child_key, v, wb) + elif isinstance(value, list): + wb.put(base_key, self.codec.dumps([])) + for k, v in enumerate(value): + k = key_to_str(k) + child_key = base_key + d + _to_bytes_key(k) + self._flatten_into_batch(child_key, v, wb) + else: + wb.put(base_key, self.codec.dumps(value)) + + def set_write_batch(self): + self._write_batch = self.db.write_batch() + + def clear_write_batch(self): + self._write_batch = None + + def get_write_batch(self): + if self._write_batch: + return self._write_batch + else: + return self.db.write_batch() + + @locked + def put(self, path, key: FLEX_KEY, value: Any) -> None: + db = self.db + if db is None: + raise RuntimeError("DB is closed") + with self.get_write_batch() as wb: + # delete any pre-existing dict + self._delete_subtree(path, key, wb=wb) + if isinstance(value, (list, dict)): + base = self._full_key(path, key) + # do not store marker at "key"; only descendants + self._flatten_into_batch(base, value, wb) + else: + wb.put(self._full_key(path, key), self.codec.dumps(value)) + + @locked + def replace(self, path, key: FLEX_KEY, value: Any) -> None: + # called by StoredObject in setattr + db = self.db + if db is None: + raise RuntimeError("DB is closed") + fullkey = self._full_key(path[:-1], path[-1]) + d = self.codec.loads(db.get(fullkey)) + d[key] = value + db.put(fullkey, self.codec.dumps(d)) + + @locked + def contains(self, path, key: object) -> bool: + db = self.db + if db is None: + raise RuntimeError("DB is closed") + if db.get(self._full_key(path, key)) is not None: + return True + return False #self._has_children(path, key) + + + # list methods + + @locked + def list_append(self, path, item): + n = self.list_len(path) + self.put(path, str(n), item) + + @locked + def list_clear(self, path): + path, key = path[:-1], path[-1] + #self._delete_subtree(path, key, wb=None) + self.put(path, key, []) + + @locked + def dict_len(self, path): + # fixme: slow + return len(list(self.iter_keys(path))) + + @locked + def list_len(self, path): + return len(list(self.iter_keys(path))) + + @locked + def list_index(self, path, item): + for k in self.iter_keys(path): + v = self.get(path, k) + if item == v: + return int(k) + raise KeyError(item) + + @locked + def list_remove(self, path, item): + k = self.list_index(path, item) + n = self.list_len(path) + for i in range(k, n-1): + self.put(path, str(i), self.get(path, str(i+1))) + self.remove(path, str(n-1)) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 49878dec5f54..8b6a1feb9a6e 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -64,10 +64,11 @@ from .lnutil import ChannelBackupStorage, ImportedChannelBackupStorage, OnchainChannelBackupStorage from .lnutil import format_short_channel_id from .fee_policy import FEERATE_PER_KW_MIN_RELAY_LIGHTNING +from .stored_dict import stored_at if TYPE_CHECKING: from .lnworker import LNWallet - from .json_db import StoredDict + from .stored_dict import StoredDict # channel flags @@ -795,7 +796,7 @@ def __init__( Logger.__init__(self) # should be after short_channel_id is set self.lnworker = lnworker self.storage = state - self.db_lock = self.storage.lock + self.db_lock = threading.RLock() if isinstance(self.storage, dict) else self.storage.lock self.config = {} self.config[LOCAL] = state["local_config"] self.config[REMOTE] = state["remote_config"] @@ -804,7 +805,7 @@ def __init__( self.node_id = bfh(state["node_id"]) self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] - self.hm = HTLCManager(log=state['log'], initiator = LOCAL if self.constraints.is_initiator else REMOTE, initial_feerate=initial_feerate) + self.hm = HTLCManager(log=state['log'], initiator = LOCAL if self.constraints.is_initiator else REMOTE, initial_feerate=initial_feerate, lock=self.db_lock) self.unfulfilled_htlcs = state["unfulfilled_htlcs"] # type: Dict[int, Optional[str]] # ^ htlc_id -> onion_packet_hex self._state = ChannelState[state['state']] diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index 9901efe4db20..2bdad634760f 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -1,11 +1,12 @@ from copy import deepcopy from typing import Sequence, Tuple, Dict, TYPE_CHECKING, Set +import threading from .lnutil import SENT, RECEIVED, LOCAL, REMOTE, HTLCOwner, UpdateAddHtlc, Direction, FeeUpdate from .util import bfh, with_lock if TYPE_CHECKING: - from .json_db import StoredDict + from .stored_dict import StoredDict LOG_TEMPLATE = { 'adds': {}, # "side who offered htlc" -> htlc_id -> htlc @@ -21,7 +22,7 @@ class HTLCManager: - def __init__(self, log: 'StoredDict', *, initiator=None, initial_feerate=None): + def __init__(self, log: 'StoredDict', *, initiator=None, initial_feerate=None, lock=None): if len(log) == 0: # note: "htlc_id" keys in dict are str! but due to json_db magic they can *almost* be treated as int... @@ -41,7 +42,7 @@ def __init__(self, log: 'StoredDict', *, initiator=None, initial_feerate=None): # lnchannel sometimes calls us with Channel.db_lock (== log.lock) already taken, # and we ourselves often take log.lock (via StoredDict.__getitem__). # Hence, to avoid deadlocks, we reuse this same lock. - self.lock = log.lock + self.lock = lock if lock else threading.RLock() self._init_maybe_active_htlc_ids() diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index d4551cc5cc7d..8408361ef810 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -52,7 +52,6 @@ from .lntransport import LNTransport, LNTransportBase, LightningPeerConnectionClosed, HandshakeFailed from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg from .interface import GracefulDisconnect -from .json_db import StoredDict from .invoices import PR_PAID from .fee_policy import FEE_LN_ETA_TARGET, FEERATE_PER_KW_MIN_RELAY_LIGHTNING from .channel_db import FLAG_DIRECTION @@ -1170,7 +1169,7 @@ async def channel_establishment_flow( lnworker=self.lnworker, initial_feerate=feerate ) - chan.storage['funding_inputs'] = [txin.prevout.to_json() for txin in funding_tx.inputs()] + chan.storage['funding_inputs'] = [txin.prevout for txin in funding_tx.inputs()] chan.storage['has_onchain_backup'] = has_onchain_backup chan.storage['init_height'] = self.lnworker.network.get_local_height() chan.storage['init_timestamp'] = int(time.time()) @@ -1220,7 +1219,9 @@ def create_channel_storage(self, channel_id, outpoint, local_config, remote_conf "revocation_store": {}, "channel_type": channel_type, } - return StoredDict(chan_dict, self.lnworker.db) + channels_db = self.lnworker.db.get_dict('channels') + channels_db[channel_id.hex()] = chan_dict + return channels_db[channel_id.hex()] @non_blocking_msg_handler async def on_open_channel(self, payload): diff --git a/electrum/lnutil.py b/electrum/lnutil.py index d46b2d4ac6e9..721648f11323 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -1940,7 +1940,7 @@ def from_tuple(amount_msat, rhash, cltv_abs, htlc_id, timestamp) -> 'UpdateAddHt htlc_id=htlc_id, timestamp=timestamp) - def to_json(self): + def as_tuple(self): self._validate() return dataclasses.astuple(self) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 41447c63e877..c4b48b371d55 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1660,8 +1660,8 @@ def add_channel(self, chan: Channel): def add_new_channel(self, chan: Channel): self.add_channel(chan) - channels_db = self.db.get_dict('channels') - channels_db[chan.channel_id.hex()] = chan.storage + #channels_db = self.db.get_dict('channels') + #channels_db[chan.channel_id.hex()] = chan.storage self.wallet.set_reserved_addresses_for_chan(chan, reserved=True) try: self.save_channel(chan) diff --git a/electrum/plugins/watchtower/watchtower.py b/electrum/plugins/watchtower/watchtower.py index 4b161304e017..a9c78bae7da1 100644 --- a/electrum/plugins/watchtower/watchtower.py +++ b/electrum/plugins/watchtower/watchtower.py @@ -36,6 +36,7 @@ from electrum.network import Network from electrum.address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL from electrum.wallet_db import WalletDB +from electrum.stored_dict import WalletStorage from electrum.lnutil import WITNESS_TEMPLATE_RECEIVED_HTLC, WITNESS_TEMPLATE_OFFERED_HTLC from electrum.logging import Logger from electrum.util import EventListener, event_listener @@ -67,7 +68,8 @@ class WatchTower(Logger, EventListener): def __init__(self, network: 'Network'): Logger.__init__(self) self.config = network.config - wallet_db = WalletDB('', storage=None, upgrade=True) + json_db = WalletStorage(None) + wallet_db = WalletDB(json_db.get_stored_dict()) self.adb = AddressSynchronizer(wallet_db, self.config, name=self.diagnostic_name()) self.adb.start_network(network) self.callbacks = {} # address -> lambda function diff --git a/electrum/storage.py b/electrum/storage.py index 7c79f2de5acc..b4d60235e8e1 100644 --- a/electrum/storage.py +++ b/electrum/storage.py @@ -28,7 +28,6 @@ import hashlib import base64 import zlib -from enum import IntEnum from typing import Optional import electrum_ecc as ecc @@ -37,7 +36,6 @@ from .util import (profiler, InvalidPassword, WalletFileException, bfh, standardize_path, test_read_write_permissions, os_chmod) -from .wallet_db import WalletDB from .logging import Logger @@ -47,10 +45,7 @@ def get_derivation_used_for_hw_device_encryption(): "/1112098098'") # ascii 'BIE2' as decimal -class StorageEncryptionVersion(IntEnum): - PLAINTEXT = 0 - USER_PASSWORD = 1 - XPUB_PASSWORD = 2 +from .stored_dict import StorageEncryptionVersion class StorageReadWriteError(Exception): pass @@ -59,8 +54,7 @@ class StorageReadWriteError(Exception): pass class StorageOnDiskUnexpectedlyChanged(Exception): pass -# TODO: Rename to Storage -class WalletStorage(Logger): +class FileStorage(Logger): # TODO maybe split this into separate create() and open() classmethods, to prevent some bugs. # Until then, the onus is on the caller to check file_exists(). @@ -74,6 +68,7 @@ def __init__( self.path = standardize_path(path) self._file_exists = bool(self.path and os.path.exists(self.path)) self.logger.info(f"wallet path {self.path}") + self._allow_partial_writes = allow_partial_writes self.pubkey = None self.decrypted = '' diff --git a/electrum/stored_dict.py b/electrum/stored_dict.py index 08810c3222c0..256b9373cbdd 100644 --- a/electrum/stored_dict.py +++ b/electrum/stored_dict.py @@ -24,24 +24,24 @@ # SOFTWARE. import threading -import json +import os +from enum import IntEnum from collections import defaultdict -from typing import TYPE_CHECKING, Optional, Sequence, List, Union, Any - - -if TYPE_CHECKING: - from .json_db import JsonDB - from .storage import WalletStorage - +from typing import Any, Optional, Tuple, Union, Iterator, Iterable, List, Sequence +from .logging import Logger +FLEX_KEY = str | int | None +_RaiseKeyError = object() # singleton for no-default behavior -def locked(func): - def wrapper(self, *args, **kwargs): - with self.lock: - return func(self, *args, **kwargs) - return wrapper +def key_to_str(x: FLEX_KEY) -> str: + if isinstance(x, int): + return str(int(x)) + elif isinstance(x, str): + return x + else: + raise Exception(f"key {x=}") registered_names = defaultdict(dict) registered_keys = defaultdict(dict) @@ -69,10 +69,68 @@ def decorator(func): return func return decorator -_FLEX_KEY = str | int | None + +class StorageEncryptionVersion(IntEnum): + PLAINTEXT = 0 + USER_PASSWORD = 1 + XPUB_PASSWORD = 2 + + +def json_default(obj): + if isinstance(obj, (set, frozenset)): + return list(obj) + if isinstance(obj, bytes): + return obj.hex() + if hasattr(obj, 'as_str') and callable(obj.as_str): + return obj.as_str() + if hasattr(obj, 'as_dict') and callable(obj.as_dict): + return obj.as_dict() + if hasattr(obj, 'as_tuple') and callable(obj.as_tuple): + return obj.as_tuple() + return obj + + +class BaseDB(Logger): + + def __init__(self, path): + Logger.__init__(self) + self._write_batch = None + self.path = path + + def get_stored_dict(self): + return StoredDict(self, key='', parent=None) + + def file_exists(self): + raise NotImplementedError() + + def get_path(self): + return self.path + + def set_password(self, password:str): + raise NotImplementedError() -def _convert_dict_key(path: List[str], key: str) -> _FLEX_KEY: +def WalletStorage(path: str, init_db: bool = True, use_levelDB: bool = False, allow_partial_writes: bool = True) -> BaseDB: + file_exists = bool(path) and os.path.exists(path) + if not file_exists: + use_levelDB = use_levelDB + elif os.path.isdir(path): + from .level_db import VERSION_FILENAME + if not os.path.exists(os.path.join(path, VERSION_FILENAME)): + raise Exception("Not an Electrum LevelDB wallet") + use_levelDB = True + else: + use_levelDB = False + if use_levelDB: + from .level_db import LevelDB + db = LevelDB(path, init_db=init_db) + else: + from .json_db import JsonDB + db = JsonDB(path=path, init_db=init_db, allow_partial_writes=allow_partial_writes) + return db + + +def _convert_dict_key(path: List[str], key: str) -> FLEX_KEY: """Maybe convert key from str to python type (typically int or IntEnum)""" assert all(isinstance(x, str) for x in path), repr(path) n = len(path) @@ -82,7 +140,7 @@ def _convert_dict_key(path: List[str], key: str) -> _FLEX_KEY: if func: key = func(key) break - assert isinstance(key, _FLEX_KEY), f"unexpected type for {key=!r} at {path=}" + assert isinstance(key, FLEX_KEY), f"unexpected type for {key=!r} at {path=}" return key def _convert_dict_value(path: List[str], v) -> Any: @@ -106,8 +164,8 @@ def _convert_dict_value(path: List[str], v) -> Any: class BaseStoredObject: - _db: 'JsonDB' = None - _key: _FLEX_KEY = None + _db: BaseDB = None + _key: FLEX_KEY = None _parent: Optional['BaseStoredObject'] = None _lock: threading.RLock = None @@ -115,9 +173,9 @@ def set_db(self, db): self._db = db self._lock = self._db.lock if self._db else threading.RLock() - def set_parent(self, *, key: _FLEX_KEY, parent: Optional['BaseStoredObject']) -> None: + def set_parent(self, *, key: FLEX_KEY, parent: Optional['BaseStoredObject']) -> None: assert (key == "") == (parent is None), f"{key=!r}, {parent=!r}" - assert isinstance(key, _FLEX_KEY), repr(key) + assert isinstance(key, FLEX_KEY), repr(key) self._key = key self._parent = parent @@ -126,7 +184,7 @@ def lock(self): return self._lock @property - def path(self) -> Sequence[_FLEX_KEY] | None: + def path(self) -> Sequence[FLEX_KEY] | None: # return None iff we are pruned from root x = self s = [x._key] @@ -135,23 +193,18 @@ def path(self) -> Sequence[_FLEX_KEY] | None: s = [x._key] + s if x._key != '': return None - assert self._db is not None + else: + assert self._db is not None return s - def db_add(self, key: _FLEX_KEY, value) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - if self.path: - self._db.add(self.path, key, value) + def _to_stored_dict_or_list(self, key, value): + """convert list to StoredList, dict to StoredDict""" + if isinstance(value, list): + value = StoredList(self._db, key=key, parent=self) + elif isinstance(value, dict): + value = StoredDict(self._db, key=key, parent=self) + return value - def db_replace(self, key: _FLEX_KEY, value) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - if self.path: - self._db.replace(self.path, key, value) - - def db_remove(self, key: _FLEX_KEY) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - if self.path: - self._db.remove(self.path, key) class StoredObject(BaseStoredObject): @@ -159,12 +212,12 @@ class StoredObject(BaseStoredObject): def __setattr__(self, key: str, value): assert isinstance(key, str), repr(key) - if self.path and not key.startswith('_'): + if not key.startswith('_') and self.path: if value != getattr(self, key): - self.db_replace(key, value) + self._db.replace(self.path, key, value) object.__setattr__(self, key, value) - def to_json(self): + def as_dict(self): d = dict(vars(self)) # don't expose/store private stuff d = {k: v for k, v in d.items() @@ -173,95 +226,220 @@ def to_json(self): -_RaiseKeyError = object() # singleton for no-default behavior +class StoredDict(BaseStoredObject): + """ + dict-like object that queries the DB + type conversions are performed here + the DB object returns simple python objects: list or dict + this class converts them + """ -class StoredDict(dict, BaseStoredObject): - - def __init__(self, data: dict, db: 'JsonDB'): - self.set_db(db) - # recursively convert dicts to StoredDict - for k, v in list(data.items()): - self.__setitem__(k, v) - - @locked - def __setitem__(self, key: _FLEX_KEY, v) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - is_new = key not in self - # early return to prevent unnecessary disk writes - if not is_new and self._db and json.dumps(v, cls=self._db.encoder) == json.dumps(self[key], cls=self._db.encoder): - return - # convert dict to StoredDict. - if type(v) == dict and (self._db is None or self._db._should_convert_to_stored_dict(key)): - v = StoredDict(v, self._db) - # convert list to StoredList - elif type(v) == list: - v = StoredList(v, self._db) - # reject sets. they do not work well with jsonpatch - elif isinstance(v, set): - raise Exception(f"Do not store sets inside jsondb. path={self.path!r}") + def __init__(self, db: BaseDB, key: FLEX_KEY, parent): + BaseStoredObject.__init__(self) + self._db = db + self._lock = db.lock + self._parent = parent + self._key = key_to_str(key) + self._should_convert = True + + def should_convert(self): + return self._parent.should_convert() if self._parent is not None else self._should_convert + + def start_upgrade(self): + self._should_convert = False + self._db.set_write_batch() + + def end_upgrade(self): + self._should_convert = True + self._db.clear_write_batch() + + def get_dict(self, key) -> 'StoredDict': + # side effect: creates db entry if it does not exist + key = key_to_str(key) + if not self._db.contains(self.path, key): + self._db.put(self.path, key, {}) + return StoredDict(self._db, key=key, parent=self) + + def dump(self): + self._should_convert = False + data = self._dump() + self._should_convert = True + return data + + def _dump(self): + data = {} + for k, v in self.items(): + if isinstance(v, StoredDict): + v = v._dump() + if isinstance(v, StoredList): + v = v._dump() + data[k] = v + return data + + def __getitem__(self, key: FLEX_KEY) -> Any: + key = key_to_str(key) + value = self._db.get(self.path, key) + if not self.should_convert(): + return self._to_stored_dict_or_list(key, value) + value = _convert_dict_value(self.path + [key], value) # set db for StoredObject, because it is not set in the constructor - if isinstance(v, StoredObject): - v.set_db(self._db) - # set parent - if isinstance(v, BaseStoredObject): - v.set_parent(key=key, parent=self) - # set item - dict.__setitem__(self, key, v) - self.db_add(key, v) if is_new else self.db_replace(key, v) - - @locked - def __delitem__(self, key: _FLEX_KEY) -> None: - assert isinstance(key, _FLEX_KEY), repr(key) - r = self.get(key, None) - dict.__delitem__(self, key) - self.db_remove(key) - if isinstance(r, BaseStoredObject): - r._parent = None - - @locked - def pop(self, key: _FLEX_KEY, v=_RaiseKeyError) -> Any: - assert isinstance(key, _FLEX_KEY), repr(key) - if key not in self: - if v is _RaiseKeyError: - raise KeyError(key) + if isinstance(value, StoredObject): + value.set_db(self._db) + value.set_parent(key=key, parent=self) + return self._to_stored_dict_or_list(key, value) + + def __setitem__(self, key: FLEX_KEY, value: Any) -> None: + key = key_to_str(key) + if isinstance(value, StoredList): + # fixme: this only happens during db upgrade? + value = value[:] + assert isinstance(value, list) + if isinstance(value, StoredDict): + value = value._dump() + assert isinstance(value, dict) + self._db.put(self.path, key, value) + + def __delitem__(self, key: FLEX_KEY) -> None: + key = key_to_str(key) + self._db.remove(self.path, key) + + def __iter__(self) -> Iterator[str]: + return self._db.iter_keys(self.path) + + def __len__(self) -> int: + return self._db.dict_len(self.path) + + # ---- Dict-like extras ---- + + def __contains__(self, key: object) -> bool: + key = key_to_str(key) + assert isinstance(key, str) + return self._db.contains(self.path, key) + + def keys(self) -> Iterable[str]: + for k in self._db.iter_keys(self.path): + yield _convert_dict_key(self.path, k) + + def values(self) -> Iterator[Any]: + for k in self.keys(): + yield self[k] + + def items(self) -> Iterator[Tuple[str, Any]]: + for k in self.keys(): + yield (k, self[k]) + + def get(self, key: FLEX_KEY, default: Any = None) -> Any: + try: + return self[key] + except KeyError: + return default + + def clear(self) -> None: + self._db.clear(self.path) + + def pop(self, key: FLEX_KEY, default: Any = _RaiseKeyError) -> Any: + try: + v = self[key] + except KeyError: + if default is _RaiseKeyError: + raise + return default + del self[key] + #if isinstance(v, StoredDict): + # v._parent = None + # v = v._dump() + #assert(not isinstance(v, BaseStoredObject)) + return v + + def update(self, other=(), /, **kwargs) -> None: + if isinstance(other, dict): + pairs = list(other.items()) + else: + pairs = list(other) + pairs.extend(kwargs.items()) + for k, v in pairs: + self[k] = v + + def as_dict(self) -> dict: + """used by keystore""" + def f(v): + if isinstance(v, StoredDict): + return v.as_dict() + elif isinstance(v, StoredList): + return v[::] else: return v - r = dict.pop(self, key) - self.db_remove(key) - if isinstance(r, BaseStoredObject): - r._parent = None - return r - - def setdefault(self, key: _FLEX_KEY, default = None, /): - assert isinstance(key, _FLEX_KEY), repr(key) + return dict([(k, f(v)) for k, v in self.items()]) + + def setdefault(self, key: FLEX_KEY, default = None, /): + assert isinstance(key, FLEX_KEY), repr(key) if key not in self: self.__setitem__(key, default) return self[key] -class StoredList(list, BaseStoredObject): +class StoredList(BaseStoredObject): - def __init__(self, data, db: 'JsonDB'): - list.__init__(self, data) - self.set_db(db) + def __init__(self, db: BaseDB, key: FLEX_KEY, parent): + self._db = db + self._lock = db.lock + self._parent = parent + self._key = key_to_str(key) + self._should_convert = True + + def should_convert(self): + return self._parent.should_convert() if self._parent is not None else self._should_convert + + def _get_list_item(self, index: int): + value = self._db.get(self.path, index) + key = str(index) + if self.should_convert(): + value = _convert_dict_value(self.path + [key], value) + value = self._to_stored_dict_or_list(key, value) + return value + + def __getitem__(self, s: slice) -> Any: + n = self._db.list_len(self.path) + if type(s) is int: + s = n + s if s < 0 else s + return self._get_list_item(s) + elif type(s) is slice: + start = 0 if s.start is None else s.start if s.start >= 0 else n + s.start + stop = n if s.stop is None else s.stop if s.stop >= 0 else n + s.stop + step = 1 if s.step is None else s.step + return [self._get_list_item(i) for i in range(start, stop, step)] + else: + raise Exception() + + def __len__(self): + return self._db.list_len(self.path) + + def __iter__(self) -> Iterator[str]: + for i in range(self._db.list_len(self.path)): + yield self._get_list_item(i) - @locked def append(self, item): - n = len(self) - list.append(self, item) - self.db_add('%d'%n, item) + self._db.list_append(self.path, item) - @locked - def remove(self, item): - n = self.index(item) - list.remove(self, item) - self.db_remove('%d'%n) - - @locked def clear(self): - list.clear(self) - self.db_replace(None, []) - + self._db.list_clear(self.path) + assert len(self) == 0 + def index(self, item) -> int: + return self._db.list_index(self.path, item) + def remove(self, item): + self._db.list_remove(self.path, item) + + def _dump(self): + data = [] + for v in self: + if isinstance(v, list): + raise + if isinstance(v, StoredDict): + v = v._dump() + if isinstance(v, StoredList): + v = v._dump() + data.append(v) + return data diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index f83d653f9d9d..c713e6e2c7bd 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -38,6 +38,7 @@ ) from . import lnutil from .lnutil import hex_to_bytes, REDEEM_AFTER_DOUBLE_SPENT_DELAY, Keypair + from .bolt11 import decode_bolt11_invoice from .stored_dict import StoredObject, stored_at from . import constants @@ -280,6 +281,7 @@ def start_network(self, network: 'Network'): for k, swap in swaps_items: if swap.is_redeemed: continue + swap._payment_hash = bytes.fromhex(k) self.add_lnwatcher_callback(swap) asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop) diff --git a/electrum/transaction.py b/electrum/transaction.py index 94d40630db02..06e5a38c868e 100644 --- a/electrum/transaction.py +++ b/electrum/transaction.py @@ -164,6 +164,9 @@ def to_legacy_tuple(self) -> Tuple[int, str, Union[int, str]]: return TYPE_ADDRESS, self.address, self.value return TYPE_SCRIPT, self.scriptpubkey.hex(), self.value + def as_tuple(self): + return self.to_legacy_tuple() + @classmethod def from_legacy_tuple(cls, _type: int, addr: str, val: Union[int, str]) -> Union['TxOutput', 'PartialTxOutput']: if _type == TYPE_ADDRESS: @@ -907,6 +910,9 @@ class Transaction: def __str__(self): return self.serialize() + def as_str(self): + return str(self) + def __init__(self, raw): if raw is None: self._cached_network_ser = None diff --git a/electrum/txbatcher.py b/electrum/txbatcher.py index 48279b40e0e0..f3f779b73082 100644 --- a/electrum/txbatcher.py +++ b/electrum/txbatcher.py @@ -76,11 +76,18 @@ from .transaction import PartialTransaction, PartialTxOutput, Transaction, TxOutpoint, PartialTxInput from .address_synchronizer import TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE from .lnsweep import SweepInfo -from .json_db import locked, StoredDict from .fee_policy import FeePolicy if TYPE_CHECKING: from .wallet import Abstract_Wallet + from .stored_dict import StoredDict + + +def locked(func): + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + return wrapper class TxBatcher(Logger): @@ -90,7 +97,7 @@ class TxBatcher(Logger): def __init__(self, wallet: 'Abstract_Wallet'): Logger.__init__(self) self.lock = threading.RLock() - self.storage = wallet.db.get_stored_item("tx_batches", {}) + self.storage = wallet.db.get_dict("tx_batches") self.tx_batches = {} # type: Dict[str, TxBatch] self.wallet = wallet for key, item_storage in self.storage.items(): @@ -228,7 +235,7 @@ def get_password_future(self, txid: str): class TxBatch(Logger): - def __init__(self, wallet: 'Abstract_Wallet', storage: StoredDict): + def __init__(self, wallet: 'Abstract_Wallet', storage: 'StoredDict'): Logger.__init__(self) self.wallet = wallet self.storage = storage diff --git a/electrum/util.py b/electrum/util.py index b4139412b2d3..a7d09c7c6d69 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -68,6 +68,7 @@ from .i18n import _ from .logging import get_logger, Logger +from .stored_dict import stored_at if TYPE_CHECKING: from .network import Network, ProxySettings @@ -340,6 +341,8 @@ def default(self, obj): return obj.hex() if hasattr(obj, 'to_json') and callable(obj.to_json): return obj.to_json() + if hasattr(obj, 'as_tuple') and callable(obj.as_tuple): + return obj.as_tuple() return super(MyEncoder, self).default(obj) @@ -1253,6 +1256,19 @@ class TxMinedInfo: header_hash: Optional[str] = None # hash of block that mined tx wanted_height: Optional[int] = None # in case of timelock, min abs block height + def as_tuple(self): + return (self._height, self.timestamp, self.txpos, self.header_hash) + + @staticmethod + @stored_at('verified_tx3/*', tuple) + def from_tuple(height, timestamp, txpos, header_hash): + return TxMinedInfo( + _height=height, + timestamp=timestamp, + txpos=txpos, + header_hash=header_hash, + ) + def height(self) -> int: """Treat unverified heights as unconfirmed.""" h = self._height diff --git a/electrum/wallet.py b/electrum/wallet.py index 0de4221029a8..a21cd920c86f 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -63,7 +63,6 @@ ) from .simple_config import SimpleConfig from .fee_policy import FeePolicy, FixedFeePolicy, FEE_RATIO_HIGH_WARNING, FEERATE_WARNING_HIGH_FEE -from .storage import StorageEncryptionVersion, WalletStorage from .wallet_db import WalletDB from .transaction import ( Transaction, TxInput, TxOutput, PartialTransaction, PartialTxInput, PartialTxOutput, TxOutpoint, Sighash @@ -82,6 +81,7 @@ from .descriptor import Descriptor from .txbatcher import TxBatcher from .submarine_swaps import MIN_SWAP_AMOUNT_SAT +from .stored_dict import WalletStorage, StorageEncryptionVersion if TYPE_CHECKING: from .network import Network @@ -407,7 +407,7 @@ def __init__(self, db: WalletDB, *, config: SimpleConfig): self.config = config assert self.config is not None, "config must not be None" self.db = db - self.storage = db.storage # type: Optional[WalletStorage] + self.storage = db.data._db # type: BaseDB # load addresses needs to be called before constructor for sanity checks db.load_addresses(self.wallet_type) self.keystore = None # type: Optional[KeyStore] # will be set by load_keystore @@ -458,7 +458,7 @@ def __init__(self, db: WalletDB, *, config: SimpleConfig): self.up_to_date_changed_event = asyncio.Event() assert self.db.get('genesis_blockhash') == constants.net.GENESIS, self.db.get('genesis_blockhash') - if self.storage and self.has_storage_encryption(): + if self.has_storage_encryption(): if (se := self.storage.get_encryption_version()) not in (ae := self.get_available_storage_encryption_versions()): raise WalletFileException(f"unexpected storage encryption type. found: {se!r}. allowed: {ae!r}") @@ -494,24 +494,36 @@ async def do_synchronize_loop(self): await run_in_thread(self.synchronize) def save_db(self): - if self.db.storage: - self.db.write() + if self.storage: + self.storage.write() def save_backup(self, backup_dir): - new_path = os.path.join(backup_dir, self.basename() + '.backup') - new_storage = WalletStorage(new_path) - new_storage._encryption_version = self.storage._encryption_version - new_storage.pubkey = self.storage.pubkey - - new_db = WalletDB(self.db.dump(), storage=new_storage, upgrade=True) + import json + from .json_db import JsonDB + from .stored_dict import json_default + # create data + data = self.db.data.dump() if self.lnworker: - channel_backups = new_db.get_dict('imported_channel_backups') + channel_backups = {} for chan_id, chan in self.lnworker.channels.items(): channel_backups[chan_id.hex()] = self.lnworker.create_channel_backup(chan_id) - new_db.put('channels', None) - new_db.put('lightning_privkey2', None) - new_db.set_modified(True) - new_db.write() + data['imported_channel_backups'] = channel_backups + data.pop('channels', None) + data.pop('lightning_privkey2', None) + json_str = json.dumps( + data, + indent=4, + sort_keys=True, + default=json_default, + ) + new_path = os.path.join(backup_dir, self.basename() + '.backup') + new_storage = JsonDB(path=new_path) + if self.storage.is_encrypted(): + new_storage.storage._encryption_version = self.storage.storage._encryption_version + new_storage.storage.pubkey = self.storage.storage.pubkey + new_storage.set_data(json_str) + new_storage.set_modified(True) + new_storage.write_and_force_consolidation() return new_path def has_lightning(self) -> bool: @@ -580,6 +592,8 @@ async def stop(self): self.save_keystore() self.db.prune_uninstalled_plugin_data(self.config.get_installed_plugins()) self.save_db() + if self.storage: + self.storage.close() def is_up_to_date(self) -> bool: if self.taskgroup and self.taskgroup.joined: # either stop() was called, or the taskgroup died @@ -2925,7 +2939,7 @@ def get_formatted_request(self, request_id): def export_request(self, x: Request) -> Dict[str, Any]: key = x.get_id() status = self.get_invoice_status(x) - d = x.as_dict(status) + d = x.export(status) d['request_id'] = d.pop('id') if x.is_lightning(): d['rhash'] = x.rhash @@ -2951,7 +2965,7 @@ def export_request(self, x: Request) -> Dict[str, Any]: def export_invoice(self, x: Invoice) -> Dict[str, Any]: key = x.get_id() status = self.get_invoice_status(x) - d = x.as_dict(status) + d = x.export(status) d['invoice_id'] = d.pop('id') if x.is_lightning(): d['lightning_invoice'] = x.lightning_invoice @@ -3178,7 +3192,9 @@ def update_password(self, old_pw, new_pw, *, encrypt_storage: bool = True, xpub_ if old_pw is None and self.has_password(): raise InvalidPassword() self.check_password(old_pw) - if self.storage: + if self.storage and encrypt_storage: + assert self.storage.supports_file_encryption() + if self.storage and self.storage.supports_file_encryption(): if encrypt_storage: enc_version = StorageEncryptionVersion.XPUB_PASSWORD if xpub_encrypt else StorageEncryptionVersion.USER_PASSWORD assert enc_version in self.get_available_storage_encryption_versions() @@ -3186,7 +3202,7 @@ def update_password(self, old_pw, new_pw, *, encrypt_storage: bool = True, xpub_ enc_version = StorageEncryptionVersion.PLAINTEXT self.storage.set_password(new_pw, enc_version) # make sure next storage.write() saves changes - self.db.set_modified(True) + self.storage.set_modified(True) # note: Encrypting storage with a hw device is currently only # allowed for non-multisig wallets. Further, @@ -3198,7 +3214,7 @@ def update_password(self, old_pw, new_pw, *, encrypt_storage: bool = True, xpub_ self.db.set_keystore_encryption(bool(new_pw) and encrypt_keystore) # save changes. force full rewrite to rm remnants of old password if self.storage and self.storage.file_exists(): - self.db.write_and_force_consolidation() + self.storage.write_and_force_consolidation() # if wallet was previously unlocked, reset password_in_memory self.lock_wallet() @@ -4196,7 +4212,7 @@ def can_enable_disable_keystore(self, ks: KeyStore) -> bool: def enable_keystore(self, keystore: KeyStore, is_hardware_keystore: bool, password) -> None: assert self.can_enable_disable_keystore(keystore) - if not is_hardware_keystore and self.storage.is_encrypted_with_user_pw(): + if not is_hardware_keystore and self.storage and self.storage.is_encrypted_with_user_pw(): keystore.update_password(None, password) self.db.put('use_encryption', True) self._update_keystore(keystore) @@ -4207,7 +4223,7 @@ def disable_keystore(self, keystore: KeyStore) -> None: assert keystore in self.get_keystores() if hasattr(keystore, 'thread') and keystore.thread: keystore.thread.stop() - if self.storage.is_encrypted_with_hw_device(): + if self.storage and self.storage.is_encrypted_with_hw_device(): password = keystore.get_password_for_storage_encryption() self.update_password(password, None, encrypt_storage=False) new = keystore.watching_only_keystore() @@ -4395,23 +4411,26 @@ def wallet_class(wallet_type): def create_new_wallet( - *, - path, - config: SimpleConfig, - passphrase: Optional[str] = None, - password: Optional[str] = None, - encrypt_file: bool = True, - seed_type: Optional[str] = None, - gap_limit: Optional[int] = None, - gap_limit_for_change: Optional[int] = None, + *, + path, + config: SimpleConfig, + passphrase: Optional[str] = None, + password: Optional[str] = None, + encrypt_file: bool = True, + seed_type: Optional[str] = None, + gap_limit: Optional[int] = None, + gap_limit_for_change: Optional[int] = None, + use_levelDB: bool = False, ) -> dict: """Create a new wallet""" - storage = WalletStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) - if storage.file_exists(): + if os.path.exists(path): raise UserFacingException("Remove the existing wallet first!") + if encrypt_file and use_levelDB: + raise UserFacingException("LevelDB wallets cannot be encrypted") + storage = WalletStorage(path, use_levelDB=use_levelDB, allow_partial_writes=config.WALLET_PARTIAL_WRITES) if encrypt_file: storage.set_password(password, StorageEncryptionVersion.USER_PASSWORD) - db = WalletDB('', storage=storage, upgrade=True) + db = WalletDB(storage.get_stored_dict()) seed = Mnemonic('en').make_seed(seed_type=seed_type) k = keystore.from_seed(seed, passphrase=passphrase) k.update_password(None, password) @@ -4427,36 +4446,43 @@ def create_new_wallet( wallet = Wallet(db, config=config) wallet.synchronize() msg = "Please keep your seed in a safe place; if you lose it, you will not be able to restore your wallet." + if not encrypt_file: + msg += "\nWarning: wallet file not encrypted. Lightning keys will not be encrypted on disk" wallet.save_db() + #storage.close() return {'seed': seed, 'wallet': wallet, 'msg': msg} def restore_wallet_from_text( - text: str, - *, - path: Optional[str], - config: SimpleConfig, - passphrase: Optional[str] = None, - password: Optional[str] = None, - encrypt_file: Optional[bool] = None, - gap_limit: Optional[int] = None, - gap_limit_for_change: Optional[int] = None, - wallet_factory = Wallet, # used in tests + text: str, + *, + path: Optional[str], + config: SimpleConfig, + passphrase: Optional[str] = None, + password: Optional[str] = None, + encrypt_file: bool = True, + gap_limit: Optional[int] = None, + gap_limit_for_change: Optional[int] = None, + use_levelDB: bool = False, + wallet_factory = Wallet, # used in tests ) -> dict: """Restore a wallet from text. Text can be a seed phrase, a master public key, a master private key, a list of bitcoin addresses or bitcoin private keys.""" - if encrypt_file is None: - encrypt_file = True - if path is None: # create wallet in-memory - storage = None + if encrypt_file and use_levelDB: + raise UserFacingException("LevelDB wallets cannot be encrypted") + if path is None: + # tests: create wallet in-memory + storage = WalletStorage(None) + storage.set_data('') else: - storage = WalletStorage(path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) - if storage.file_exists(): + if os.path.exists(path): raise UserFacingException("Remove the existing wallet first!") + storage = WalletStorage(path, use_levelDB=use_levelDB, allow_partial_writes=config.WALLET_PARTIAL_WRITES) if encrypt_file: storage.set_password(password, StorageEncryptionVersion.USER_PASSWORD) - db = WalletDB('', storage=storage, upgrade=True) + + db = WalletDB(storage.get_stored_dict()) db.set_keystore_encryption(bool(password)) text = text.strip() if keystore.is_address_list(text): @@ -4495,10 +4521,11 @@ def restore_wallet_from_text( if gap_limit_for_change is not None: db.put('gap_limit_for_change', gap_limit_for_change) wallet = wallet_factory(db, config=config) - if db.storage: - assert not db.storage.file_exists(), "file was created too soon! plaintext keys might have been written to disk" wallet.synchronize() msg = ("This wallet was restored offline. It may contain more addresses than displayed. " "Start a daemon and use load_wallet to sync its history.") + if not encrypt_file: + msg += "\nWarning: wallet file not encrypted. Lightning keys will not be encrypted on disk." wallet.save_db() + #storage.close() return {'wallet': wallet, 'msg': msg} diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index c53c355b0753..c373b290e848 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -35,22 +35,19 @@ from . import bitcoin from . import constants -from .util import profiler, WalletFileException, multisig_type, TxMinedInfo, MyEncoder +from .util import with_lock as locked +from .util import profiler, WalletFileException, multisig_type, TxMinedInfo from .keystore import bip44_derivation from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput, BadHeaderMagic from .logging import Logger from .lnutil import HTLCOwner, ChannelType, RecvMPPResolution -from .json_db import JsonDB, locked, modifier -from . import stored_dict -from .stored_dict import StoredObject, stored_at, register_key, register_name +from .stored_dict import register_name, register_key +from .stored_dict import StoredObject, StoredDict, StoredList, stored_at from .plugin import run_hook, plugin_loaders from .version import ELECTRUM_VERSION from .i18n import _ -if TYPE_CHECKING: - from .storage import WalletStorage - class WalletRequiresUpgrade(WalletFileException): pass @@ -107,8 +104,13 @@ class WalletFileExceptionVersion51(WalletFileException): pass register_name('transactions/*', None, lambda x: tx_from_any(x, deserialize=False)) register_name('data_loss_protect_remote_pcp/*', None, lambda x: bytes.fromhex(x)) # register tuples, otherwise they will default to StoredList +register_name('closing_height', None, tuple) +register_name('funding_height', None, tuple) +register_name('forwarding_failures/*', None, tuple) +register_name('lightning_payments/*', None, tuple) register_name('contacts/*', None, tuple) register_name('lightning_preimages/*', None, tuple) +register_name('addr_history/*/*', None, tuple) # register dicts that require key conversion for key in [ 'adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets', @@ -121,12 +123,9 @@ class WalletFileExceptionVersion51(WalletFileException): pass class WalletDBUpgrader(Logger): - def __init__(self, data: dict): + def __init__(self, data: StoredDict): Logger.__init__(self) self.data = data - # self.data must be in-memory dict (not a StoredDict or similar), - # so a failed, partial upgrade won't get commited to disk - assert type(self.data) == dict, type(self.data) def get(self, key, default=None): return self.data.get(key, default) @@ -150,10 +149,10 @@ def get_split_accounts(self): wallet_type = self.get('wallet_type') if wallet_type == 'old': assert len(d) == 2 - data1 = copy.deepcopy(self.data) + data1 = copy.deepcopy(self.data.as_dict()) data1['accounts'] = {'0': d['0']} data1['suffix'] = 'deterministic' - data2 = copy.deepcopy(self.data) + data2 = copy.deepcopy(self.data.as_dict()) data2['accounts'] = {'/x': d['/x']} data2['seed'] = None data2['seed_version'] = None @@ -167,11 +166,11 @@ def get_split_accounts(self): mpk = self.get('master_public_keys') for k in d.keys(): i = int(k) - x = d[k] + x = d[k].as_dict() if x.get("pending"): continue xpub = mpk["x/%d'"%i] - new_data = copy.deepcopy(self.data) + new_data = copy.deepcopy(self.data.as_dict()) # save account, derivation and xpub at index 0 new_data['accounts'] = {'0': x} new_data['master_public_keys'] = {"x/0'": xpub} @@ -187,6 +186,7 @@ def requires_upgrade(self): @profiler def upgrade(self): + assert self.data.should_convert() is False self.logger.info('upgrading wallet format') self._convert_imported() self._convert_wallet_type() @@ -267,6 +267,8 @@ def _convert_wallet_type(self): xprvs = self.get('master_private_keys', {}) mpk = self.get('master_public_key') keypairs = self.get('keypairs') + if keypairs: + keypairs = keypairs.as_dict() key_type = self.get('key_type') if seed_version == OLD_SEED_VERSION or wallet_type == 'old': d = { @@ -368,14 +370,14 @@ def _convert_version_14(self): if self.get('wallet_type') =='imported': addresses = self.get('addresses') - if type(addresses) is list: + if type(addresses) is StoredList: addresses = dict([(x, None) for x in addresses]) self.put('addresses', addresses) elif self.get('wallet_type') == 'standard': if self.get('keystore').get('type')=='imported': addresses = set(self.get('addresses').get('receiving')) pubkeys = self.get('keystore').get('keypairs').keys() - assert len(addresses) == len(pubkeys) + assert len(addresses) == len(list(pubkeys)) d = {} for pubkey in pubkeys: addr = bitcoin.pubkey_to_address('p2pkh', pubkey) @@ -427,7 +429,7 @@ def remove_from_list(list_name): if self.get('wallet_type') == 'imported': addresses = self.get('addresses') - assert isinstance(addresses, dict) + assert isinstance(addresses, StoredDict) addresses_new = dict() for address, details in addresses.items(): if not bitcoin.is_address(address): @@ -488,6 +490,7 @@ def _convert_version_20(self): for ks_name in ('keystore', *['x{}/'.format(i) for i in range(1, 16)]): ks = self.get(ks_name, None) if ks is None: continue + assert isinstance(ks, StoredDict) xpub = ks.get('xpub', None) if xpub is None: continue bip32node = BIP32Node.from_xkey(xpub) @@ -512,7 +515,6 @@ def _convert_version_20(self): root_fingerprint = bip32node.fingerprint.hex() ks['root_fingerprint'] = root_fingerprint ks.pop('ckcc_xfp', None) - self.put(ks_name, ks) self.put('seed_version', 20) @@ -587,7 +589,7 @@ def _convert_version_24(self): # convert channels to dict self.data['channels'] = {x['channel_id']: x for x in channels} # convert txi & txo - txi = self.get('txi', {}) + txi = self.data.get_dict('txi') for tx_hash, d in list(txi.items()): d2 = {} for addr, l in d.items(): @@ -595,8 +597,7 @@ def _convert_version_24(self): for ser, v in l: d2[addr][ser] = v txi[tx_hash] = d2 - self.data['txi'] = txi - txo = self.get('txo', {}) + txo = self.data.get_dict('txo') for tx_hash, d in list(txo.items()): d2 = {} for addr, l in d.items(): @@ -604,7 +605,6 @@ def _convert_version_24(self): for n, v, cb in l: d2[addr][str(n)] = (v, cb) txo[tx_hash] = d2 - self.data['txo'] = txo self.data['seed_version'] = 24 @@ -779,10 +779,12 @@ def _convert_version_35(self): if not self._is_upgrade_method_needed(34, 34): return PR_TYPE_ONCHAIN = 0 - requests_old = self.data.get('payment_requests', {}) - requests_new = {k: item for k, item in requests_old.items() - if not (item['type'] == PR_TYPE_ONCHAIN and item['outputs'] is None)} - self.data['payment_requests'] = requests_new + payment_requests = self.data.get('payment_requests', {}) + for k in payment_requests.keys(): + item = payment_requests[k] + if (item['type'] == PR_TYPE_ONCHAIN and item['outputs'] is None): + payment_requests.pop(k) + self.data['seed_version'] = 35 def _convert_version_36(self): @@ -885,13 +887,12 @@ def _convert_version_42(self): def _convert_version_43(self): if not self._is_upgrade_method_needed(42, 42): return - channels = self.data.pop('channels', {}) + channels = self.data.get('channels', {}) for k, c in channels.items(): log = c['log'] c['fail_htlc_reasons'] = log.pop('fail_htlc_reasons', {}) c['unfulfilled_htlcs'] = log.pop('unfulfilled_htlcs', {}) log["1"]['unacked_updates'] = log.pop('unacked_local_updates2', {}) - self.data['channels'] = channels self.data['seed_version'] = 43 def _convert_version_44(self): @@ -921,7 +922,7 @@ def _convert_version_45(self): for key, item in invoices.items(): is_lightning = item['type'] == 2 lightning_invoice = item['invoice'] if is_lightning else None - outputs = item['outputs'] if not is_lightning else None + outputs = item['outputs'][::] if not is_lightning else None bip70 = item['bip70'] if not is_lightning else None if is_lightning: lnaddr = decode_bolt11_invoice(item['invoice']) @@ -957,6 +958,7 @@ def get_id_from_onchain_outputs(raw_outputs, timestamp): outputs = [PartialTxOutput.from_legacy_tuple(*output) for output in raw_outputs] outputs_str = "\n".join(f"{txout.scriptpubkey.hex()}, {txout.value}" for txout in outputs) return sha256d(outputs_str + "%d" % timestamp).hex()[0:10] + assert isinstance(invoices, StoredDict) for key, item in list(invoices.items()): is_lightning = item['lightning_invoice'] is not None if is_lightning: @@ -966,14 +968,15 @@ def get_id_from_onchain_outputs(raw_outputs, timestamp): timestamp = item['time'] newkey = get_id_from_onchain_outputs(outputs_raw, timestamp) if newkey != key: - invoices[newkey] = item + invoices[newkey] = item.as_dict() del invoices[key] def _convert_version_46(self): if not self._is_upgrade_method_needed(45, 45): return - invoices = self.data.get('invoices', {}) - self._convert_invoices_keys(invoices) + invoices = self.data.get('invoices') + if invoices: + self._convert_invoices_keys(invoices) self.data['seed_version'] = 46 def _convert_version_47(self): @@ -1021,8 +1024,9 @@ def _convert_version_49(self): def _convert_version_50(self): if not self._is_upgrade_method_needed(49, 49): return - requests = self.data.get('payment_requests', {}) - self._convert_invoices_keys(requests) + requests = self.data.get('payment_requests') + if requests: + self._convert_invoices_keys(requests) self.data['seed_version'] = 50 def _convert_version_51(self): @@ -1104,7 +1108,9 @@ def _convert_version_55(self): # do not use '/' in dict keys for key in list(self.data.keys()): if key.endswith('/'): - self.data[key[:-1]] = self.data.pop(key) + item = self.data.get(key) + self.data[key[:-1]] = item.as_dict() + self.data.pop(key) self.data['seed_version'] = 55 def _convert_version_56(self): @@ -1159,7 +1165,6 @@ def _convert_version_59(self): for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_key) in chan['unfulfilled_htlcs'].items(): unfulfilled_htlcs[htlc_id] = (onion_packet_hex, forwarding_key or None) chan['unfulfilled_htlcs'] = unfulfilled_htlcs - self.data['channels'] = channels self.data['seed_version'] = 59 def _convert_version_60(self): @@ -1359,7 +1364,7 @@ def _convert_version_67(self): key = '-1' if is_initiator else '1' assert len(chan['log'][key]['fee_updates']) == 1, chan['log'][key]['fee_updates'] chan['log'][key]['fee_updates'] = {} - self.data['channels'] = channels + #self.data['channels'] = channels self.data['seed_version'] = 67 def _convert_version_68(self): @@ -1523,7 +1528,7 @@ def _raise_unsupported_version(self, seed_version): raise WalletFileException(msg) -def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: +def upgrade_wallet_db(data: 'StoredDict', do_upgrade: bool) -> Tuple[dict, bool]: was_upgraded = False if len(data) == 0: @@ -1536,7 +1541,7 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: first_electrum_version_used=ELECTRUM_VERSION, ) assert data.get("db_metadata", None) is None - data["db_metadata"] = v.to_json() + data["db_metadata"] = v.as_dict() was_upgraded = True # Test mainnet/testnet mixup. Do this before DB upgrades, as those might assume # network magic bytes (e.g. if they parse an address or an xpub). @@ -1546,6 +1551,7 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: "Current chain: {}").format(constants.net.NET_NAME) ) + data.start_upgrade() dbu = WalletDBUpgrader(data) if dbu.requires_split(): raise WalletRequiresSplit(dbu.get_split_accounts()) @@ -1554,30 +1560,58 @@ def upgrade_wallet_db(data: dict, do_upgrade: bool) -> Tuple[dict, bool]: was_upgraded = True if dbu.requires_upgrade(): raise WalletRequiresUpgrade() - return dbu.data, was_upgraded - - -class WalletDB(JsonDB): - - def __init__( - self, - s: str, - *, - storage: Optional['WalletStorage'] = None, - upgrade: bool = False, - ): - JsonDB.__init__( - self, - s, - storage=storage, - encoder=MyEncoder, - upgrader=partial(upgrade_wallet_db, do_upgrade=upgrade), - ) + data.end_upgrade() + return was_upgraded + + + +@stored_at('txo/*/*/*', tuple) +class TxoValue(NamedTuple): + value: int + is_coinbase: bool + + +class WalletDB(Logger): + + def __init__(self, data: 'StoredDict', upgrade: bool = True): + Logger.__init__(self) + self.data = data + self.lock = self.data.lock + # we must perform db upgrades on the storeddict + was_upgraded = upgrade_wallet_db(self.data, upgrade) + #self._modified |= was_upgraded + # create pointers self.load_transactions() # load plugins that are conditional on wallet type self.load_plugins() + @locked + def put(self, key, value): + # raises if value cannot be serialized by db + if value is not None: + if self.data.get(key) != value: + self.data[key] = copy.deepcopy(value) + return True + elif key in self.data: + self.data.pop(key) + return True + return False + + @locked + def get(self, key, default=None): + return self.data.get(key, default) + + @locked + def get_dict(self, name) -> dict: + return self.data.get_dict(name) + + @locked + def get_stored_item(self, name, default): + if name not in self.data: + self.data[name] = default + return self.data[name] + @locked def get_seed_version(self): return self.get('seed_version') @@ -1612,9 +1646,9 @@ def get_txo_addr(self, tx_hash: str, address: str) -> Dict[int, Tuple[int, bool] assert isinstance(tx_hash, str) assert isinstance(address, str) d = self.txo.get(tx_hash, {}).get(address, {}) - return {int(n): (v, cb) for (n, (v, cb)) in d.items()} + return {int(n): (item.value, item.is_coinbase) for (n, item) in d.items()} - @modifier + @locked def add_txi_addr(self, tx_hash: str, addr: str, ser: str, v: int) -> None: assert isinstance(tx_hash, str) assert isinstance(addr, str) @@ -1627,7 +1661,7 @@ def add_txi_addr(self, tx_hash: str, addr: str, ser: str, v: int) -> None: d[addr] = {} d[addr][ser] = v - @modifier + @locked def add_txo_addr(self, tx_hash: str, addr: str, n: Union[int, str], v: int, is_coinbase: bool) -> None: n = str(n) assert isinstance(tx_hash, str) @@ -1640,7 +1674,7 @@ def add_txo_addr(self, tx_hash: str, addr: str, n: Union[int, str], v: int, is_c d = self.txo[tx_hash] if addr not in d: d[addr] = {} - d[addr][n] = (v, is_coinbase) + d[addr][n] = TxoValue(v, is_coinbase) @locked def list_txi(self) -> Sequence[str]: @@ -1650,12 +1684,12 @@ def list_txi(self) -> Sequence[str]: def list_txo(self) -> Sequence[str]: return list(self.txo.keys()) - @modifier + @locked def remove_txi(self, tx_hash: str) -> None: assert isinstance(tx_hash, str) self.txi.pop(tx_hash, None) - @modifier + @locked def remove_txo(self, tx_hash: str) -> None: assert isinstance(tx_hash, str) self.txo.pop(tx_hash, None) @@ -1678,7 +1712,7 @@ def get_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> O prevout_n = str(prevout_n) return self.spent_outpoints.get(prevout_hash, {}).get(prevout_n) - @modifier + @locked def remove_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) -> None: assert isinstance(prevout_hash, str) prevout_n = str(prevout_n) @@ -1686,7 +1720,7 @@ def remove_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str]) - if not self.spent_outpoints[prevout_hash]: self.spent_outpoints.pop(prevout_hash) - @modifier + @locked def set_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str], tx_hash: str) -> None: assert isinstance(prevout_hash, str) assert isinstance(tx_hash, str) @@ -1695,7 +1729,7 @@ def set_spent_outpoint(self, prevout_hash: str, prevout_n: Union[int, str], tx_h self.spent_outpoints[prevout_hash] = {} self.spent_outpoints[prevout_hash][prevout_n] = tx_hash - @modifier + @locked def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: assert isinstance(scripthash, str) assert isinstance(prevout, TxOutpoint) @@ -1704,7 +1738,7 @@ def add_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, val self._prevouts_by_scripthash[scripthash] = dict() self._prevouts_by_scripthash[scripthash][prevout.to_str()] = value - @modifier + @locked def remove_prevout_by_scripthash(self, scripthash: str, *, prevout: TxOutpoint, value: int) -> None: assert isinstance(scripthash, str) assert isinstance(prevout, TxOutpoint) @@ -1719,7 +1753,7 @@ def get_prevouts_by_scripthash(self, scripthash: str) -> Set[Tuple[TxOutpoint, i prevouts_and_values = self._prevouts_by_scripthash.get(scripthash, {}) return {(TxOutpoint.from_str(prevout), value) for prevout, value in prevouts_and_values.items()} - @modifier + @locked def add_transaction(self, tx_hash: str, tx: Transaction) -> None: assert isinstance(tx_hash, str) assert isinstance(tx, Transaction), tx @@ -1735,7 +1769,7 @@ def add_transaction(self, tx_hash: str, tx: Transaction) -> None: if tx_we_already_have is None or isinstance(tx_we_already_have, PartialTransaction): self.transactions[tx_hash] = tx - @modifier + @locked def remove_transaction(self, tx_hash: str) -> Optional[Transaction]: assert isinstance(tx_hash, str) return self.transactions.pop(tx_hash, None) @@ -1765,12 +1799,12 @@ def get_addr_history(self, addr: str) -> Sequence[Tuple[str, int]]: assert isinstance(addr, str) return self.history.get(addr, []) - @modifier + @locked def set_addr_history(self, addr: str, hist) -> None: assert isinstance(addr, str) self.history[addr] = hist - @modifier + @locked def remove_addr_history(self, addr: str) -> None: assert isinstance(addr, str) self.history.pop(addr, None) @@ -1784,22 +1818,17 @@ def get_verified_tx(self, txid: str) -> Optional[TxMinedInfo]: assert isinstance(txid, str) if txid not in self.verified_tx: return None - height, timestamp, txpos, header_hash = self.verified_tx[txid] - return TxMinedInfo(_height=height, - conf=None, - timestamp=timestamp, - txpos=txpos, - header_hash=header_hash) - - @modifier + return self.verified_tx[txid] + + @locked def add_verified_tx(self, txid: str, info: TxMinedInfo): assert isinstance(txid, str) assert isinstance(info, TxMinedInfo) height = info._height # number of conf is dynamic and might not be set here assert height > 0, height - self.verified_tx[txid] = (height, info.timestamp, info.txpos, info.header_hash) + self.verified_tx[txid] = info - @modifier + @locked def remove_verified_tx(self, txid: str): assert isinstance(txid, str) self.verified_tx.pop(txid, None) @@ -1808,7 +1837,7 @@ def is_in_verified_tx(self, txid: str) -> bool: assert isinstance(txid, str) return txid in self.verified_tx - @modifier + @locked def add_tx_fee_from_server(self, txid: str, fee_sat: Optional[int]) -> None: assert isinstance(txid, str) # note: when called with (fee_sat is None), rm currently saved value @@ -1819,7 +1848,7 @@ def add_tx_fee_from_server(self, txid: str, fee_sat: Optional[int]) -> None: return self.tx_fees[txid] = tx_fees_value._replace(fee=fee_sat, is_calculated_by_us=False) - @modifier + @locked def add_tx_fee_we_calculated(self, txid: str, fee_sat: Optional[int]) -> None: assert isinstance(txid, str) if fee_sat is None: @@ -1840,7 +1869,7 @@ def get_tx_fee(self, txid: str, *, trust_server: bool = False) -> Optional[int]: return None return tx_fees_value.fee - @modifier + @locked def add_num_inputs_to_tx(self, txid: str, num_inputs: int) -> None: assert isinstance(txid, str) assert isinstance(num_inputs, int) @@ -1862,7 +1891,7 @@ def get_num_ismine_inputs_of_tx(self, txid: str) -> int: txins = self.txi.get(txid, {}) return sum([len(tupls) for addr, tupls in txins.items()]) - @modifier + @locked def remove_tx_fee(self, txid: str) -> None: assert isinstance(txid, str) self.tx_fees.pop(txid, None) @@ -1885,13 +1914,13 @@ def get_receiving_addresses(self, *, slice_start=None, slice_stop=None) -> List[ # note: slicing makes a shallow copy return self.receiving_addresses[slice_start:slice_stop] - @modifier + @locked def add_change_address(self, addr: str) -> None: assert isinstance(addr, str) self._addr_to_addr_index[addr] = (1, len(self.change_addresses)) self.change_addresses.append(addr) - @modifier + @locked def add_receiving_address(self, addr: str) -> None: assert isinstance(addr, str) self._addr_to_addr_index[addr] = (0, len(self.receiving_addresses)) @@ -1902,12 +1931,12 @@ def get_address_index(self, address: str) -> Optional[Sequence[int]]: assert isinstance(address, str) return self._addr_to_addr_index.get(address) - @modifier + @locked def add_imported_address(self, addr: str, d: dict) -> None: assert isinstance(addr, str) self.imported_addresses[addr] = d - @modifier + @locked def remove_imported_address(self, addr: str) -> None: assert isinstance(addr, str) self.imported_addresses.pop(addr) @@ -1931,10 +1960,10 @@ def load_addresses(self, wallet_type): if wallet_type == 'imported': self.imported_addresses = self.get_dict('addresses') # type: Dict[str, dict] else: - self.get_dict('addresses') + addresses = self.get_dict('addresses') for name in ['receiving', 'change']: - if name not in self.data['addresses']: - self.data['addresses'][name] = [] + if name not in addresses: + addresses[name] = [] self.change_addresses = self.data['addresses']['change'] self.receiving_addresses = self.data['addresses']['receiving'] self._addr_to_addr_index = {} # type: Dict[str, Sequence[int]] # key: address, value: (is_change, index) @@ -1971,7 +2000,7 @@ def load_transactions(self): self.logger.info("removing unreferenced spent outpoint") d.pop(prevout_n) - @modifier + @locked def clear_history(self): self.txi.clear() self.txo.clear() @@ -1982,23 +2011,17 @@ def clear_history(self): self.tx_fees.clear() self._prevouts_by_scripthash.clear() - def _should_convert_to_stored_dict(self, key) -> bool: - if key == 'keystore': - return False - multisig_keystore_names = [('x%d' % i) for i in range(1, 16)] - if key in multisig_keystore_names: - return False - return True - @classmethod def split_accounts(klass, root_path, split_data): - from .storage import WalletStorage + # not covered by tests + from .json_db import JsonDB file_list = [] for data in split_data: path = root_path + '.' + data['suffix'] - item_storage = WalletStorage(path) - db = WalletDB(json.dumps(data), storage=item_storage, upgrade=True) - db.write() + storage = JsonDB(path) + storage.set_data(json.dumps(data)) + db = WalletDB(storage.get_stored_dict(), upgrade=True) + storage.write() file_list.append(path) return file_list diff --git a/electrum/wizard.py b/electrum/wizard.py index 6b2e7b65d2cd..b1f0e087f462 100644 --- a/electrum/wizard.py +++ b/electrum/wizard.py @@ -12,7 +12,8 @@ from electrum.network import ProxySettings from electrum.plugin import run_hook from electrum.slip39 import EncryptedSeed -from electrum.storage import WalletStorage, StorageEncryptionVersion, StorageReadWriteError +from electrum.storage import StorageEncryptionVersion, StorageReadWriteError +from electrum.stored_dict import WalletStorage from electrum.util import UserFacingException from electrum.wallet_db import WalletDB from electrum.bip32 import normalize_bip32_derivation, xpub_type @@ -780,7 +781,7 @@ def create_storage(self, path: str, data: dict): enc_version = StorageEncryptionVersion.USER_PASSWORD storage.set_password(data['password'], enc_version=enc_version) - db = WalletDB('', storage=storage, upgrade=True) + db = WalletDB(storage.get_stored_dict()) db.set_keystore_encryption(bool(data['password'])) db.put('wallet_type', data['wallet_type']) @@ -823,7 +824,8 @@ def create_storage(self, path: str, data: dict): db.put('lightning_xprv', k.get_lightning_xprv(data['password'])) db.load_plugins() - db.write() + storage.write() + storage.close() class ServerConnectWizard(AbstractWizard): diff --git a/run_electrum b/run_electrum index 42f77a3c3dd3..691c165ed6a4 100755 --- a/run_electrum +++ b/run_electrum @@ -124,7 +124,7 @@ from electrum.payment_identifier import PaymentIdentifier from electrum import SimpleConfig from electrum.wallet_db import WalletDB from electrum.wallet import Wallet -from electrum.storage import WalletStorage +from electrum.stored_dict import WalletStorage from electrum.util import print_msg, print_stderr, json_encode, json_decode, UserCancelled from electrum.util import InvalidPassword from electrum.plugin import Plugins @@ -167,8 +167,8 @@ def init_cmdline(config_options, wallet_path, *, rpcserver: bool, config: 'Simpl print_msg("wallet path not provided.") sys_exit(1) - # instantiate wallet for command-line - storage = WalletStorage(wallet_path, allow_partial_writes=config.WALLET_PARTIAL_WRITES) if wallet_path else None + # instantiate storage without opening the DB, so that we can check if it is encrypted + storage = WalletStorage(wallet_path, init_db=False) if wallet_path else None if cmd.requires_wallet and not storage.file_exists(): print_msg("Error: Wallet file not found.") @@ -256,7 +256,7 @@ async def run_offline_command(config: 'SimpleConfig', config_options: dict, wall password = get_password_for_hw_device_encrypted_storage(plugins) config_options['password'] = password storage.decrypt(password) - db = WalletDB(storage.read(), storage=storage, upgrade=True) + db = WalletDB(storage.get_stored_dict()) wallet = Wallet(db, config=config) config_options['wallet'] = wallet else: diff --git a/tests/plugins/test_timelock_recovery.py b/tests/plugins/test_timelock_recovery.py index 8c57b19c6d28..dc83647baea6 100644 --- a/tests/plugins/test_timelock_recovery.py +++ b/tests/plugins/test_timelock_recovery.py @@ -5,8 +5,8 @@ from electrum.bitcoin import address_to_script from electrum.fee_policy import FixedFeePolicy from electrum.simple_config import SimpleConfig -from electrum.storage import WalletStorage from electrum.transaction import PartialTxOutput +from electrum.stored_dict import WalletStorage from electrum.wallet import Wallet from electrum.wallet_db import WalletDB @@ -37,7 +37,8 @@ def _create_default_wallet(self): with open(os.path.join(os.path.dirname(__file__), "test_timelock_recovery", "default_wallet"), "r") as f: wallet_str = f.read() storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage.set_data(wallet_str) + db = WalletDB(storage.get_stored_dict(), upgrade=True) wallet = Wallet(db, config=self.config) return wallet diff --git a/tests/regtest/regtest.sh b/tests/regtest/regtest.sh index 0dcadfed9b1d..4ece6ed5f596 100755 --- a/tests/regtest/regtest.sh +++ b/tests/regtest/regtest.sh @@ -4,8 +4,9 @@ set -eu TEST_SRK_CHANNELS=False -# alice -> bob -> carol +USE_LEVELDB="--use_levelDB --encrypt_file=false" +# alice -> bob -> carol alice="./run_electrum --regtest -D /tmp/alice" bob="./run_electrum --regtest -D /tmp/bob" carol="./run_electrum --regtest -D /tmp/carol" @@ -190,7 +191,7 @@ if [[ $1 == "init" ]]; then echo "initializing $2" rm -rf /tmp/$2/ agent="./run_electrum --regtest -D /tmp/$2" - $agent create --offline > /dev/null + $agent create --offline $USE_LEVELDB > /dev/null $agent setconfig --offline test_ln_open_srk_channels $TEST_SRK_CHANNELS $agent setconfig --offline log_to_file True $agent setconfig --offline use_gossip True @@ -687,7 +688,7 @@ if [[ $1 == "breach_with_spent_htlc" ]]; then echo "enable_htlc_settle did not work, $unsettled" exit 1 fi - cp /tmp/alice/regtest/wallets/default_wallet /tmp/alice/regtest/wallets/toxic_wallet + cp -r /tmp/alice/regtest/wallets/default_wallet /tmp/alice/regtest/wallets/toxic_wallet $bob enable_htlc_settle true unsettled=$($alice list_channels | jq '.[] | .local_unsettled_sent') if [[ "$unsettled" != "0" ]]; then diff --git a/tests/test_bitcoin.py b/tests/test_bitcoin.py index 78a6b523d96a..ee8e3f6ddff0 100644 --- a/tests/test_bitcoin.py +++ b/tests/test_bitcoin.py @@ -28,7 +28,7 @@ from electrum.crypto import sha256d, SUPPORTED_PW_HASH_VERSIONS from electrum import crypto, constants from electrum.util import bfh, InvalidPassword, randrange -from electrum.storage import WalletStorage +from electrum.storage import FileStorage from electrum.keystore import xtype_from_derivation from . import ElectrumTestCase @@ -270,7 +270,7 @@ def test_signmessage_segwit_witness_v0_address_test_we_also_accept_sigs_from_tre @needs_test_with_all_aes_implementations def test_decrypt_message(self): - key = WalletStorage.get_eckey_from_password('pw123') + key = FileStorage.get_eckey_from_password('pw123') self.assertEqual(b'me<(s_s)>age', crypto.ecies_decrypt_message( key, b'QklFMQMDFtgT3zWSQsa+Uie8H/WvfUjlu9UN9OJtTt3KlgKeSTi6SQfuhcg1uIz9hp3WIUOFGTLr4RNQBdjPNqzXwhkcPi2Xsbiw6UCNJncVPJ6QBg==')) self.assertEqual(b'me<(s_s)>age', crypto.ecies_decrypt_message( @@ -280,7 +280,7 @@ def test_decrypt_message(self): @needs_test_with_all_aes_implementations def test_encrypt_message(self): - key = WalletStorage.get_eckey_from_password('secret_password77') + key = FileStorage.get_eckey_from_password('secret_password77') msgs = [ bytes([0] * 555), b'cannot think of anything funny' diff --git a/tests/test_commands.py b/tests/test_commands.py index f294c6da5b58..dda588c4d844 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,7 +11,7 @@ import electrum from electrum.commands import Commands, eval_bool from electrum import storage, wallet -from electrum.lnutil import RECEIVED +from electrum.lnutil import RECEIVED, ReceivedMPPStatus, UpdateAddHtlc, ReceivedMPPHtlc from electrum.lnworker import RecvMPPResolution from electrum.wallet import Abstract_Wallet from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED @@ -549,16 +549,28 @@ async def test_hold_invoice_commands(self, mock_save_db): wallet=wallet, ) - mock_htlc1 = mock.Mock() - mock_htlc1.htlc.cltv_abs = 800_000 - mock_htlc1.htlc.amount_msat = 4_500_000 - mock_htlc2 = mock.Mock() - mock_htlc2.htlc.cltv_abs = 800_144 - mock_htlc2.htlc.amount_msat = 5_500_000 - mock_htlc_status = mock.Mock() - mock_htlc_status.htlcs = [mock_htlc1, mock_htlc2] - mock_htlc_status.resolution = RecvMPPResolution.COMPLETE - + mock_htlc1 = ReceivedMPPHtlc( + channel_id='', + htlc = UpdateAddHtlc( + cltv_abs = 800_000, + amount_msat = 4_500_000, + payment_hash=bytes(32), + ), + unprocessed_onion='', + ) + mock_htlc2 = ReceivedMPPHtlc( + channel_id = '', + htlc = UpdateAddHtlc( + cltv_abs = 800_144, + amount_msat = 5_500_000, + payment_hash=bytes(32), + ), + unprocessed_onion = '', + ) + mock_htlc_status = ReceivedMPPStatus( + htlcs = [mock_htlc1, mock_htlc2], + resolution = RecvMPPResolution.COMPLETE, + ) payment_key = wallet.lnworker._get_payment_key(bytes.fromhex(payment_hash)).hex() with mock.patch.dict(wallet.lnworker.received_mpp_htlcs, {payment_key: mock_htlc_status}): status: dict = await cmds.check_hold_invoice(payment_hash=payment_hash, wallet=wallet) @@ -587,8 +599,8 @@ async def test_hold_invoice_commands(self, mock_save_db): # cancelling a settled invoice should raise await cmds.cancel_hold_invoice(payment_hash=payment_hash, wallet=wallet) - @mock.patch.object(storage.WalletStorage, 'write') - @mock.patch.object(storage.WalletStorage, 'append') + @mock.patch.object(storage.FileStorage, 'write') + @mock.patch.object(storage.FileStorage, 'append') async def test_onchain_history(self, *mock_args): cmds = Commands(config=self.config, daemon=self.daemon) wallet_path = self.get_wallet_file_path("client_3_3_8_xpub_with_realistic_history") diff --git a/tests/test_jsondb.py b/tests/test_jsondb.py index bb07eaf1a286..e128dc7e16e0 100644 --- a/tests/test_jsondb.py +++ b/tests/test_jsondb.py @@ -10,7 +10,7 @@ from . import ElectrumTestCase -from electrum.json_db import JsonDB +from electrum.stored_dict import WalletStorage class TestJsonpatch(ElectrumTestCase): @@ -123,13 +123,16 @@ async def test_jsondb_replace_after_remove(self): for pop_from_dict in [pop1_from_dict, pop2_from_dict]: with self.subTest(pop_from_dict): data = { 'a': {'b': {'c': 0}}, 'd': 3} - db = JsonDB(repr(data)) - a = db.get_dict('a') + db = WalletStorage(None) + db.set_data(repr(data)) + sd = db.get_stored_dict() + a = sd.get_dict('a') # remove b = pop_from_dict(a, 'b') self.assertEqual(len(db.pending_changes), 1) # replace item. this must not been written to db - b['c'] = 42 + with self.assertRaises(KeyError): + b['c'] = 42 self.assertEqual(len(db.pending_changes), 1) patches = json.loads('[' + ','.join(db.pending_changes) + ']') jpatch = jsonpatch.JsonPatch(patches) @@ -140,13 +143,16 @@ async def test_jsondb_replace_after_remove_nested(self): for pop_from_dict in [pop1_from_dict, pop2_from_dict]: with self.subTest(pop_from_dict): data = { 'a': {'b': {'c': 0}}, 'd': 3} - db = JsonDB(repr(data)) + db = WalletStorage(None) + db.set_data(repr(data)) + sd = db.get_stored_dict() # remove - a = pop_from_dict(db.data, "a") + a = pop_from_dict(sd, "a") self.assertEqual(len(db.pending_changes), 1) - b = a['b'] - # replace item. this must not be written to db - b['c'] = 42 + with self.assertRaises(KeyError): + b = a['b'] + # replace item. this must not be written to db + b['c'] = 42 self.assertEqual(len(db.pending_changes), 1) patches = json.loads('[' + ','.join(db.pending_changes) + ']') jpatch = jsonpatch.JsonPatch(patches) diff --git a/tests/test_lnchannel.py b/tests/test_lnchannel.py index 972855fbe65d..81eb0b3d6404 100644 --- a/tests/test_lnchannel.py +++ b/tests/test_lnchannel.py @@ -45,7 +45,6 @@ ) from electrum.logging import console_stderr_handler from electrum.lnchannel import ChannelState, Channel -from electrum.json_db import StoredDict from electrum.coinchooser import PRNG from . import ElectrumTestCase @@ -119,7 +118,7 @@ def create_channel_state( 'revocation_store': {}, 'channel_type': channel_type, } - return StoredDict(state, None) + return state def create_test_channels( diff --git a/tests/test_lnhtlc.py b/tests/test_lnhtlc.py index c8ac53db39b8..5cac3dd6bbbc 100644 --- a/tests/test_lnhtlc.py +++ b/tests/test_lnhtlc.py @@ -4,7 +4,6 @@ from electrum.lnutil import RECEIVED, LOCAL, REMOTE, SENT, HTLCOwner, Direction from electrum.lnhtlc import HTLCManager -from electrum.json_db import StoredDict from . import ElectrumTestCase @@ -14,8 +13,8 @@ class H(NamedTuple): class TestHTLCManager(ElectrumTestCase): def test_adding_htlcs_race(self): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() ah0, bh0 = H('A', 0), H('B', 0) @@ -61,8 +60,8 @@ def test_adding_htlcs_race(self): def test_single_htlc_full_lifecycle(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() B.recv_htlc(A.send_htlc(H('A', 0))) @@ -134,8 +133,8 @@ def htlc_lifecycle(htlc_success: bool): def test_remove_htlc_while_owing_commitment(self): def htlc_lifecycle(htlc_success: bool): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() ah0 = H('A', 0) @@ -171,8 +170,8 @@ def htlc_lifecycle(htlc_success: bool): htlc_lifecycle(htlc_success=False) def test_adding_htlc_between_send_ctx_and_recv_rev(self): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() A.send_ctx() @@ -217,8 +216,8 @@ def test_adding_htlc_between_send_ctx_and_recv_rev(self): self.assertEqual([(Direction.RECEIVED, ah0)], A.get_htlcs_in_next_ctx(REMOTE)) def test_unacked_local_updates(self): - A = HTLCManager(StoredDict({}, None)) - B = HTLCManager(StoredDict({}, None)) + A = HTLCManager({}) + B = HTLCManager({}) A.channel_open_finished() B.channel_open_finished() self.assertEqual({}, A.get_unacked_local_updates()) diff --git a/tests/test_lnutil.py b/tests/test_lnutil.py index 5df8755ce524..45ab22516a13 100644 --- a/tests/test_lnutil.py +++ b/tests/test_lnutil.py @@ -3,7 +3,6 @@ from typing import Dict, List from electrum import bitcoin -from electrum.json_db import StoredDict from electrum.lnutil import ( RevocationStore, get_per_commitment_secret_from_seed, make_offered_htlc, make_received_htlc, make_commitment, make_htlc_tx_witness, make_htlc_tx_output, make_htlc_tx_inputs, secret_to_pubkey, derive_blinded_pubkey, @@ -12,7 +11,8 @@ IncompatibleLightningFeatures, ChannelType, offered_htlc_trim_threshold_sat, received_htlc_trim_threshold_sat, ImportedChannelBackupStorage, list_enabled_ln_feature_bits, PaymentFeeBudget, LnFeatureContexts ) -from electrum.util import bfh, MyEncoder +from electrum.util import bfh +from electrum.stored_dict import json_default, WalletStorage from electrum.transaction import Transaction, PartialTransaction, Sighash from electrum.lnworker import LNWallet from electrum.wallet import Standard_Wallet @@ -474,10 +474,10 @@ def test_shachain_store(self): ] for test in tests: - receiver = RevocationStore(StoredDict({}, None)) + storage = WalletStorage(None) + receiver = RevocationStore(storage.get_stored_dict()) for insert in test["inserts"]: secret = bytes.fromhex(insert["secret"]) - try: receiver.add_next_entry(secret) except Exception as e: @@ -497,7 +497,8 @@ def test_shachain_store(self): def test_shachain_produce_consume(self): seed = bitcoin.sha256(b"shachaintest") - consumer = RevocationStore(StoredDict({}, None)) + storage = WalletStorage(None) + consumer = RevocationStore(storage.get_stored_dict()) for i in range(10000): secret = get_per_commitment_secret_from_seed(seed, RevocationStore.START_INDEX - i) try: @@ -506,9 +507,11 @@ def test_shachain_produce_consume(self): raise Exception("iteration " + str(i) + ": " + str(e)) if i % 1000 == 0: c1 = consumer - s1 = json.dumps(c1.storage, cls=MyEncoder) - c2 = RevocationStore(StoredDict(json.loads(s1), None)) - s2 = json.dumps(c2.storage, cls=MyEncoder) + s1 = json.dumps(storage.json_data, default=json_default) + storage2 = WalletStorage(None) + storage2.set_data(s1) + c2 = RevocationStore(storage2.get_stored_dict()) + s2 = json.dumps(storage2.json_data, default=json_default) self.assertEqual(s1, s2) def test_commitment_tx_with_all_five_HTLCs_untrimmed_minimum_feerate(self): diff --git a/tests/test_storage_upgrade.py b/tests/test_storage_upgrade.py index 4be5cbe66517..cfc0914c582b 100644 --- a/tests/test_storage_upgrade.py +++ b/tests/test_storage_upgrade.py @@ -7,6 +7,8 @@ import inspect import electrum +from electrum.stored_dict import WalletStorage +from electrum.stored_dict import StoredDict from electrum.wallet_db import WalletDBUpgrader, WalletDB, WalletRequiresUpgrade, WalletRequiresSplit from electrum.wallet import Wallet from electrum import constants @@ -358,7 +360,9 @@ async def _upgrade_storage(self, wallet_json, accounts=1) -> Optional[WalletDB]: self.assertEqual(accounts, len(split_data)) for item in split_data: data = json.dumps(item) - new_db = WalletDB(data, storage=None, upgrade=True) + storage = WalletStorage(None) + storage.set_data(data) + new_db = WalletDB(storage.get_stored_dict(), upgrade=True) await self._sanity_check_upgraded_db(new_db) async def _sanity_check_upgraded_db(self, db): @@ -367,5 +371,7 @@ async def _sanity_check_upgraded_db(self, db): @staticmethod def _load_db_from_json_string(*, wallet_json, upgrade): - db = WalletDB(wallet_json, storage=None, upgrade=upgrade) + storage = WalletStorage(None) + storage.set_data(wallet_json) + db = WalletDB(storage.get_stored_dict(), upgrade=upgrade) return db diff --git a/tests/test_wallet.py b/tests/test_wallet.py index de1a278a6ce1..44c65abf7fe6 100644 --- a/tests/test_wallet.py +++ b/tests/test_wallet.py @@ -10,14 +10,14 @@ from unittest import mock from pathlib import Path -from electrum.storage import WalletStorage from electrum.wallet_db import FINAL_SEED_VERSION from electrum.wallet import (Abstract_Wallet, Standard_Wallet, create_new_wallet, Imported_Wallet, Wallet) from electrum.exchange_rate import ExchangeBase, FxThread from electrum.util import TxMinedInfo, InvalidPassword from electrum.bitcoin import COIN -from electrum.wallet_db import WalletDB, JsonDB +from electrum.wallet_db import WalletDB +from electrum.stored_dict import WalletStorage from electrum.simple_config import SimpleConfig from electrum import util, storage from electrum.daemon import Daemon @@ -66,15 +66,13 @@ def test_read_dictionary_from_file(self): with open(self.wallet_path, "w") as f: contents = f.write(contents) - storage = WalletStorage(self.wallet_path) - db = JsonDB(storage.read(), storage=storage) - self.assertEqual("b", db.get("a")) - self.assertEqual("d", db.get("c")) + db = WalletStorage(self.wallet_path) + self.assertEqual("b", db.get([''], "a")) + self.assertEqual("d", db.get([''], "c")) def test_write_dictionary_to_file(self): - storage = WalletStorage(self.wallet_path) - db = JsonDB('', storage=storage) + db = WalletStorage(self.wallet_path) some_dict = { u"a": u"b", @@ -82,7 +80,7 @@ def test_write_dictionary_to_file(self): u"seed_version": FINAL_SEED_VERSION} for key, value in some_dict.items(): - db.put(key, value) + db.put([''], key, value) db.write() with open(self.wallet_path, "r") as f: @@ -91,6 +89,24 @@ def test_write_dictionary_to_file(self): for key, value in some_dict.items(): self.assertEqual(d[key], value) + def _test_db_roundtrip(self, use_levelDB): + db = WalletStorage(self.wallet_path, use_levelDB=use_levelDB) + sd = db.get_stored_dict() + # list containing list and dict + some_list = [[1, 2], {"c": "d"} ] + sd['1'] = some_list + self.assertEqual(sd['1']._dump(), some_list) + # dict containing list and dict + some_dict = {"a": [1, 2], "b": {"c":"d"} } + sd['2'] = some_dict + self.assertEqual(sd['2']._dump(), some_dict) + + def test_jsondb_roundtrip(self): + self._test_db_roundtrip(False) + + def test_leveldb_roundtrip(self): + self._test_db_roundtrip(True) + async def test_storage_imported_add_privkeys_persistence_test(self): text = ' '.join([ 'p2wpkh:L4jkdiXszG26SUYvwwJhzGwg37H2nLhrbip7u6crmgNeJysv5FHL', @@ -172,7 +188,8 @@ class FakeWallet: def __init__(self, fiat_value): super().__init__() self.fiat_value = fiat_value - self.db = WalletDB('', storage=None, upgrade=False) + storage = WalletStorage(None) + self.db = WalletDB(storage.get_stored_dict()) self.adb = FakeADB() self.db.transactions = self.db.verified_tx = {'abc':'Tx'} @@ -234,8 +251,8 @@ def tearDown(self): time.tzset() @mock.patch('electrum.wallet.run_hook') - @mock.patch.object(storage.WalletStorage, 'write') - @mock.patch.object(storage.WalletStorage, 'append') + @mock.patch.object(storage.FileStorage, 'write') + @mock.patch.object(storage.FileStorage, 'append') async def test_export_history_to_file(self, _mock_append, _mock_write, mock_run_hook): # prepare wallet with realistic history c = self.config @@ -325,7 +342,7 @@ async def test_restore_wallet_from_text_no_storage(self): config=self.config, ) wallet = d['wallet'] # type: Standard_Wallet - self.assertEqual(None, wallet.storage) + self.assertEqual(None, wallet.storage.storage) self.assertEqual(text, wallet.keystore.get_seed(None)) self.assertEqual('bc1q3g5tmkmlvxryhh843v4dz026avatc0zzr6h3af', wallet.get_receiving_addresses()[0]) @@ -379,7 +396,8 @@ class TestWalletPassword(WalletTestCase): async def test_update_password_of_imported_wallet(self): wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}' storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage.set_data(wallet_str) + db = WalletDB(storage.get_stored_dict()) wallet = Wallet(db, config=self.config) wallet.check_password(None) @@ -394,8 +412,9 @@ async def test_update_password_of_imported_wallet(self): async def test_update_password_of_standard_wallet(self): wallet_str = '''{"addr_history":{"12ECgkzK6gHouKAZ7QiooYBuk1CgJLJxes":[],"12iR43FPb5M7sw4Mcrr5y1nHKepg9EtZP1":[],"13HT1pfWctsSXVFzF76uYuVdQvcAQ2MAgB":[],"13kG9WH9JqS7hyCcVL1ssLdNv4aXocQY9c":[],"14Tf3qiiHJXStSU4KmienAhHfHq7FHpBpz":[],"14gmBxYV97mzYwWdJSJ3MTLbTHVegaKrcA":[],"15FGuHvRssu1r8fCw98vrbpfc3M4xs5FAV":[],"17oJzweA2gn6SDjsKgA9vUD5ocT1sSnr2Z":[],"18hNcSjZzRcRP6J2bfFRxp9UfpMoC4hGTv":[],"18n9PFxBjmKCGhd4PCDEEqYsi2CsnEfn2B":[],"19a98ZfEezDNbCwidVigV5PAJwrR2kw4Jz":[],"19z3j2ELqbg2pR87byCCt3BCyKR7rc3q8G":[],"1A3XSmvLQvePmvm7yctsGkBMX9ZKKXLrVq":[],"1CmhFe2BN1h9jheFpJf4v39XNPj8F9U6d":[],"1DuphhHUayKzbkdvjVjf5dtjn2ACkz4zEs":[],"1E4ygSNJpWL2uPXZHBptmU2LqwZTqb1Ado":[],"1GTDSjkVc9vaaBBBGNVqTANHJBcoT5VW9z":[],"1GWqgpThAuSq3tDg6uCoLQxPXQNnU8jZ52":[],"1GhmpwqSF5cqNgdr9oJMZx8dKxPRo4pYPP":[],"1J5TTUQKhwehEACw6Jjte1E22FVrbeDmpv":[],"1JWySzjzJhsETUUcqVZHuvQLA7pfFfmesb":[],"1KQHxcy3QUHAWMHKUtJjqD9cMKXcY2RTwZ":[],"1KoxZfc2KsgovjGDxwqanbFEA76uxgYH4G":[],"1KqVEPXdpbYvEbwsZcEKkrA4A2jsgj9hYN":[],"1N16yDSYe76c5A3CoVoWAKxHeAUc8Jhf9J":[],"1Pm8JBhzUJDqeQQKrmnop1Frr4phe1jbTt":[]},"addresses":{"change":["1GhmpwqSF5cqNgdr9oJMZx8dKxPRo4pYPP","1GTDSjkVc9vaaBBBGNVqTANHJBcoT5VW9z","15FGuHvRssu1r8fCw98vrbpfc3M4xs5FAV","1A3XSmvLQvePmvm7yctsGkBMX9ZKKXLrVq","19z3j2ELqbg2pR87byCCt3BCyKR7rc3q8G","1JWySzjzJhsETUUcqVZHuvQLA7pfFfmesb"],"receiving":["14gmBxYV97mzYwWdJSJ3MTLbTHVegaKrcA","13HT1pfWctsSXVFzF76uYuVdQvcAQ2MAgB","19a98ZfEezDNbCwidVigV5PAJwrR2kw4Jz","1J5TTUQKhwehEACw6Jjte1E22FVrbeDmpv","1Pm8JBhzUJDqeQQKrmnop1Frr4phe1jbTt","13kG9WH9JqS7hyCcVL1ssLdNv4aXocQY9c","1KQHxcy3QUHAWMHKUtJjqD9cMKXcY2RTwZ","12ECgkzK6gHouKAZ7QiooYBuk1CgJLJxes","12iR43FPb5M7sw4Mcrr5y1nHKepg9EtZP1","14Tf3qiiHJXStSU4KmienAhHfHq7FHpBpz","1KqVEPXdpbYvEbwsZcEKkrA4A2jsgj9hYN","17oJzweA2gn6SDjsKgA9vUD5ocT1sSnr2Z","1E4ygSNJpWL2uPXZHBptmU2LqwZTqb1Ado","18hNcSjZzRcRP6J2bfFRxp9UfpMoC4hGTv","1KoxZfc2KsgovjGDxwqanbFEA76uxgYH4G","18n9PFxBjmKCGhd4PCDEEqYsi2CsnEfn2B","1CmhFe2BN1h9jheFpJf4v39XNPj8F9U6d","1DuphhHUayKzbkdvjVjf5dtjn2ACkz4zEs","1GWqgpThAuSq3tDg6uCoLQxPXQNnU8jZ52","1N16yDSYe76c5A3CoVoWAKxHeAUc8Jhf9J"]},"keystore":{"seed":"cereal wise two govern top pet frog nut rule sketch bundle logic","type":"bip32","xprv":"xprv9s21ZrQH143K29XjRjUs6MnDB9wXjXbJP2kG1fnRk8zjdDYWqVkQYUqaDtgZp5zPSrH5PZQJs8sU25HrUgT1WdgsPU8GbifKurtMYg37d4v","xpub":"xpub661MyMwAqRbcEdcCXm1sTViwjBn28zK9kFfrp4C3JUXiW1sfP34f6HA45B9yr7EH5XGzWuTfMTdqpt9XPrVQVUdgiYb5NW9m8ij1FSZgGBF"},"pruned_txo":{},"seed_type":"standard","seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[619,310,840,405]}''' - storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage = WalletStorage(path=self.wallet_path) + storage.set_data(wallet_str) + db = WalletDB(storage.get_stored_dict()) wallet = Wallet(db, config=self.config) wallet.check_password(None) @@ -424,14 +443,15 @@ async def test_update_password_of_standard_wallet_oldseed(self): async def test_update_password_with_app_restarts(self): wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}' storage = WalletStorage(self.wallet_path) - db = WalletDB(wallet_str, storage=storage, upgrade=True) + storage.set_data(wallet_str) + db = WalletDB(storage.get_stored_dict()) wallet = Wallet(db, config=self.config) await wallet.stop() storage = WalletStorage(self.wallet_path) # if storage.is_encrypted(): # storage.decrypt(password) - db = WalletDB(storage.read(), storage=storage, upgrade=True) + db = WalletDB(storage.get_stored_dict()) wallet = Wallet(db, config=self.config) wallet.check_password(None) diff --git a/tests/test_wallet_vertical.py b/tests/test_wallet_vertical.py index c0e26e0ddabd..ab220f611f54 100644 --- a/tests/test_wallet_vertical.py +++ b/tests/test_wallet_vertical.py @@ -7,8 +7,9 @@ import copy from electrum import bitcoin, keystore, bip32, slip39, wallet +from electrum.json_db import JsonDB from electrum.wallet_db import WalletDB -from electrum.storage import WalletStorage +from electrum.stored_dict import WalletStorage, BaseDB from electrum import SimpleConfig from electrum import util from electrum.address_synchronizer import TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE @@ -55,7 +56,7 @@ def check_xpub_keystore_sanity(cls, test_obj, ks): @classmethod def create_standard_wallet(cls, ks, *, config: SimpleConfig, gap_limit=None, gap_limit_for_change=None): - db = WalletDB('', storage=None, upgrade=True) + db = WalletDB(JsonDB('').get_stored_dict()) db.put('keystore', ks.dump()) db.put('gap_limit', gap_limit or cls.gap_limit) db.put('gap_limit_for_change', gap_limit_for_change or cls.gap_limit_for_change) @@ -65,7 +66,7 @@ def create_standard_wallet(cls, ks, *, config: SimpleConfig, gap_limit=None, gap @classmethod def create_imported_wallet(cls, *, config: SimpleConfig, privkeys: bool): - db = WalletDB('', storage=None, upgrade=True) + db = WalletDB(JsonDB('').get_stored_dict()) if privkeys: k = keystore.Imported_KeyStore({}) db.put('keystore', k.dump()) @@ -79,12 +80,14 @@ def create_multisig_wallet( multisig_type: str, *, config: SimpleConfig, - storage: WalletStorage | None = None, + storage: BaseDB | None = None, gap_limit=None, gap_limit_for_change=None, ): """Creates a multisig wallet.""" - db = WalletDB('', storage=storage, upgrade=False) + if storage is None: + storage = WalletStorage(None) + db = WalletDB(storage.get_stored_dict()) for i, ks in enumerate(keystores): cosigner_index = i + 1 db.put('x%d' % cosigner_index, ks.dump()) diff --git a/tests/test_wizard.py b/tests/test_wizard.py index 7c5b740a8611..43a6b4945817 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -12,7 +12,7 @@ from electrum import slip39 from electrum.bip32 import KeyOriginInfo from electrum import keystore -from electrum.storage import WalletStorage +from electrum.stored_dict import WalletStorage from . import ElectrumTestCase from .test_wallet_vertical import UNICODE_HORROR, WalletIntegrityHelper