Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 170 additions & 11 deletions electrum/lnmsg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,21 @@
from types import MappingProxyType
from collections import OrderedDict

import electrum_ecc as ecc

from . import bitcoin
from .lnutil import OnionFailureCodeMetaFlag
from .util import chunks


class FailedToParseMsg(Exception):
msg_type_int: Optional[int] = None
msg_type_name: Optional[str] = None


class UnknownMsgType(FailedToParseMsg): pass
class UnknownOptionalMsgType(UnknownMsgType): pass
class UnknownMandatoryMsgType(UnknownMsgType): pass

class MalformedMsg(FailedToParseMsg): pass
class UnknownMsgFieldType(MalformedMsg): pass
class UnexpectedEndOfStream(MalformedMsg): pass
Expand All @@ -24,6 +28,7 @@ class UnknownMandatoryTLVRecordType(MalformedMsg): pass
class MsgTrailingGarbage(MalformedMsg): pass
class MsgInvalidFieldOrder(MalformedMsg): pass
class UnexpectedFieldSizeForEncoder(MalformedMsg): pass
class MsgInvalidSignature(MalformedMsg): pass


def _num_remaining_bytes_to_read(fd: io.BytesIO) -> int:
Expand Down Expand Up @@ -94,7 +99,7 @@ def _read_primitive_field(
fd: io.BytesIO,
field_type: str,
count: Union[int, str]
) -> Union[bytes, int]:
) -> Union[bytes, int, str]:
if not fd:
raise Exception()
if isinstance(count, int):
Expand Down Expand Up @@ -150,6 +155,8 @@ def _read_primitive_field(
type_len = 32
elif field_type == 'signature':
type_len = 64
elif field_type == 'bip340sig':
type_len = 64
elif field_type == 'point':
type_len = 33
elif field_type == 'short_channel_id':
Expand All @@ -166,6 +173,9 @@ def _read_primitive_field(
if len(buf) != type_len:
raise UnexpectedEndOfStream()
return buf
elif field_type == 'utf8':
if count != '...':
raise Exception(f"utf8 fields can only have unbounded count")

if count == "...":
total_len = -1 # read all
Expand All @@ -177,6 +187,20 @@ def _read_primitive_field(
buf = fd.read(total_len)
if total_len >= 0 and len(buf) != total_len:
raise UnexpectedEndOfStream()

if field_type == 'utf8':
try:
return buf.decode('utf-8')
except UnicodeDecodeError as e:
raise MalformedMsg(f'invalid utf-8: {buf.hex()}') from e

if field_type == 'point':
for point in chunks(buf, type_len):
try:
ecc.ECPubkey(b=point)
except ecc.keys.InvalidECPointException as e:
raise MalformedMsg(f"invalid point: {point.hex()}") from e

return buf


Expand All @@ -186,7 +210,7 @@ def _write_primitive_field(
fd: io.BytesIO,
field_type: str,
count: Union[int, str],
value: Union[bytes, int]
value: Union[bytes, int, str]
) -> None:
if not fd:
raise Exception()
Expand Down Expand Up @@ -246,6 +270,8 @@ def _write_primitive_field(
type_len = 32
elif field_type == 'signature':
type_len = 64
elif field_type == 'bip340sig':
type_len = 64
elif field_type == 'point':
type_len = 33
elif field_type == 'short_channel_id':
Expand All @@ -258,6 +284,10 @@ def _write_primitive_field(
type_len = 33 # point
else:
raise Exception(f"invalid sciddir_or_pubkey, prefix byte not in range 0-3")
elif field_type == 'utf8':
if count != '...':
raise Exception(f"utf8 fields can only have unbounded count")
value = value.encode('utf-8')
total_len = -1
if count != "...":
if type_len is None:
Expand All @@ -274,12 +304,16 @@ def _write_primitive_field(
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?")


def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes]:
def _read_tlv_record(*, fd: io.BytesIO) -> Tuple[int, bytes, bytes]:
if not fd: raise Exception()
pos_start = fd.tell()
tlv_type = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
tlv_len = _read_primitive_field(fd=fd, field_type="bigsize", count=1)
tlv_val = _read_primitive_field(fd=fd, field_type="byte", count=tlv_len)
return tlv_type, tlv_val
pos_end = fd.tell()
fd.seek(pos_start)
rawbytes = fd.read(pos_end - pos_start)
return tlv_type, tlv_val, rawbytes


def _write_tlv_record(*, fd: io.BytesIO, tlv_type: int, tlv_val: bytes) -> None:
Expand Down Expand Up @@ -321,6 +355,46 @@ def _parse_msgtype_intvalue_for_onion_wire(value: str) -> int:
return msg_type_int


def _tlv_merkle_root(tlvs: List[Sequence[bytes]]) -> bytes:
first_tlv = None
tlv_merkle_nodes = []

for tlvt, tlv in tlvs:
if first_tlv is None:
first_tlv = tlv
tlv_val = tlv
tlv_record_type = write_bigsize_int(tlvt)
merkle_leaf_hash = bitcoin.bip340_tagged_hash(b'LnLeaf', tlv_val)
merkle_nonce = bitcoin.bip340_tagged_hash(b'LnNonce' + first_tlv, tlv_record_type)

# ascending order
msg = merkle_leaf_hash + merkle_nonce if merkle_leaf_hash < merkle_nonce else merkle_nonce + merkle_leaf_hash
merkle_node_hash = bitcoin.bip340_tagged_hash(b'LnBranch', msg)

tlv_merkle_nodes.append(merkle_node_hash)

while len(tlv_merkle_nodes) > 1:
target = []
for chunk in chunks(tlv_merkle_nodes, 2):
if len(chunk) == 1:
target.append(chunk[0])
else:
msg = chunk[0] + chunk[1] if chunk[0] < chunk[1] else chunk[1] + chunk[0]
merkle_node_hash = bitcoin.bip340_tagged_hash(b'LnBranch', msg)
target.append(merkle_node_hash)
tlv_merkle_nodes = target

return tlv_merkle_nodes[0]


def _is_bolt12_signature_tlv_type(tlv_type: int) -> bool:
"""
bolt12: each form is signed using one or more *signature TLV elements*: TLV
types 240 through 1000 (inclusive)
"""
return tlv_type in range(240, 1001)


class LNSerializer:

def __init__(self, *, name: str = 'peer_wire'):
Expand All @@ -331,6 +405,7 @@ def __init__(self, *, name: str = 'peer_wire'):
self.in_tlv_stream_get_tlv_record_scheme_from_type = {} # type: Dict[str, Dict[int, List[Sequence[str]]]]
self.in_tlv_stream_get_record_type_from_name = {} # type: Dict[str, Dict[str, int]]
self.in_tlv_stream_get_record_name_from_type = {} # type: Dict[str, Dict[int, str]]
self.in_tlv_stream_signature_tlv_records = {} # type: Dict[str, Dict[int, str]]

self.subtypes = {} # type: Dict[str, Dict[str, Sequence[str]]]

Expand Down Expand Up @@ -393,7 +468,17 @@ def __init__(self, *, name: str = 'peer_wire'):
assert fieldname not in self.subtypes[subtypename], f"duplicate field definition for {fieldname} for subtype {subtypename}"
self.subtypes[subtypename][fieldname] = tuple(row)
else:
pass # TODO
pass # TODO: raise?

for stream_name, scheme_map in self.in_tlv_stream_get_tlv_record_scheme_from_type.items():
sig_records = {}
for tlv_type, scheme in scheme_map.items():
for row in scheme:
if row[0] == 'tlvdata' and row[4] == 'bip340sig':
assert _is_bolt12_signature_tlv_type(tlv_type), f"bip340sig field outside bolt 12 range: {stream_name=} {tlv_type=}"
sig_records[tlv_type] = row[3] # e.g. 240: 'sig'
break
self.in_tlv_stream_signature_tlv_records[stream_name] = sig_records # e.g. 'invoice_request': {240: 'sig'}

def write_field(
self,
Expand Down Expand Up @@ -495,14 +580,42 @@ def read_field(
count=subtype_field_count)
parsedlist.append(parsed)

# fd might contain more bytes, but we got passed a count. break when we have 'count' items.
# (e.g. nested complex types)
if isinstance(count, int) and len(parsedlist) == count:
break

return parsedlist if count == '...' or count > 1 else parsedlist[0]

def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) -> None:
def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, signing_key: Optional[bytes] = None, **kwargs) -> None:
sign_over_tlvs = []
sig_tlv_type, sig_tlv_record_name = None, None
sig_records = self.in_tlv_stream_signature_tlv_records.get(tlv_stream_name, {})
if signing_key is not None:
sig_tlv_types = list(sig_records.keys()) # e.g. [240] ('signature')
if len(sig_tlv_types) != 1:
raise NotImplementedError
sig_tlv_type = sig_tlv_types[0]
sig_tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][sig_tlv_type]
assert sig_tlv_record_name not in kwargs, f"pass either existing {sig_tlv_record_name} or signing_key"

scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
for tlv_record_type, scheme in scheme_map.items(): # note: tlv_record_type is monotonically increasing
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]

is_signature_record = tlv_record_type in sig_records
if tlv_record_name not in kwargs:
continue
# skip record_name if not in kwargs, unless we need to generate it
if not is_signature_record or signing_key is None:
continue
# calculate signature over previously serialized tlv records
# and store in kwargs for inclusion in tlv stream
merkle_root = _tlv_merkle_root(sign_over_tlvs)
priv = ecc.ECPrivkey(signing_key)
tag = b'lightning' + tlv_stream_name.encode('ascii') + sig_tlv_record_name.encode('ascii')
signature = priv.schnorr_sign(bitcoin.bip340_tagged_hash(tag, merkle_root))
kwargs[tlv_record_name] = {sig_records[tlv_record_type]: signature} # e.g. 'signature': {'sig': <sig over root>}

with io.BytesIO() as tlv_record_fd:
for row in scheme:
if row[0] == "tlvtype":
Expand All @@ -525,14 +638,29 @@ def write_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str, **kwargs) ->
value=field_value)
else:
raise Exception(f"unexpected row in scheme: {row!r}")
_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_record_fd.getvalue())

def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str, Dict[str, Any]]:
tlv_val = tlv_record_fd.getvalue()

_write_tlv_record(fd=fd, tlv_type=tlv_record_type, tlv_val=tlv_val)

# signature TLVs are excluded from the bolt 12 merkle root
if signing_key is not None and not is_signature_record:
with io.BytesIO() as tlvfd:
_write_tlv_record(fd=tlvfd, tlv_type=tlv_record_type, tlv_val=tlv_val)
sign_over_tlvs.append((tlv_record_type, tlvfd.getvalue()))

def read_tlv_stream(self, *,
fd: io.BytesIO,
tlv_stream_name: str,
signing_key_path: Optional[Sequence[str]] = None) -> Dict[str, Dict[str, Any]]:
sign_over_tlvs = []
signature_record = None # type: Optional[tuple[int, str]] # (tlv_type, tlv_record_name)
parsed = {} # type: Dict[str, Dict[str, Any]]
scheme_map = self.in_tlv_stream_get_tlv_record_scheme_from_type[tlv_stream_name]
sig_records = self.in_tlv_stream_signature_tlv_records.get(tlv_stream_name, {})
last_seen_tlv_record_type = -1 # type: int
while _num_remaining_bytes_to_read(fd) > 0:
tlv_record_type, tlv_record_val = _read_tlv_record(fd=fd)
tlv_record_type, tlv_record_val, rawbytes = _read_tlv_record(fd=fd)
if not (tlv_record_type > last_seen_tlv_record_type):
raise MsgInvalidFieldOrder(f"TLV records must be monotonically increasing by type. "
f"cur: {tlv_record_type}. prev: {last_seen_tlv_record_type}")
Expand All @@ -545,8 +673,20 @@ def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str,
raise UnknownMandatoryTLVRecordType(f"{tlv_stream_name}/{tlv_record_type}") from None
else:
# unknown "odd" type: skip it
if signing_key_path and not _is_bolt12_signature_tlv_type(tlv_record_type):
sign_over_tlvs.append((tlv_record_type, rawbytes))
continue
tlv_record_name = self.in_tlv_stream_get_record_name_from_type[tlv_stream_name][tlv_record_type]

# collect tlvs for deferred signature check
if signing_key_path:
if tlv_record_type in sig_records:
if signature_record is not None:
raise MalformedMsg(f"multiple signatures in {tlv_stream_name=} not supported")
signature_record = (tlv_record_type, tlv_record_name)
else:
sign_over_tlvs.append((tlv_record_type, rawbytes))

parsed[tlv_record_name] = {}
with io.BytesIO(tlv_record_val) as tlv_record_fd:
for row in scheme:
Expand All @@ -573,6 +713,25 @@ def read_tlv_stream(self, *, fd: io.BytesIO, tlv_stream_name: str) -> Dict[str,
raise Exception(f"unexpected row in scheme: {row!r}")
if _num_remaining_bytes_to_read(tlv_record_fd) > 0:
raise MsgTrailingGarbage(f"TLV record ({tlv_stream_name}/{tlv_record_name}) has extra trailing garbage")

if signing_key_path:
if signature_record is None:
raise MalformedMsg(f"expected signature in {tlv_stream_name}")
sig_tlv_type, sig_tlv_record_name = signature_record
merkle_root = _tlv_merkle_root(sign_over_tlvs)
signing_pubkey = parsed
for key in signing_key_path: # walk signing_key_path
signing_pubkey = signing_pubkey[key]
assert isinstance(signing_pubkey, bytes)
sig_field_name = sig_records[sig_tlv_type]
sig_bytes = parsed[sig_tlv_record_name][sig_field_name]
if not isinstance(sig_bytes, bytes) or len(sig_bytes) != 64:
raise MsgInvalidSignature(f"invalid signature data in {tlv_stream_name}/{sig_tlv_record_name}: {sig_bytes=}")
tag = b'lightning' + tlv_stream_name.encode('ascii') + sig_tlv_record_name.encode('ascii')
tagh = bitcoin.bip340_tagged_hash(tag, merkle_root)
if not ecc.ECPubkey(signing_pubkey).schnorr_verify(sig_bytes, tagh):
raise MsgInvalidSignature(f"invalid signature in {'.'.join(signing_key_path)}")

return parsed

def encode_msg(self, msg_type: str, **kwargs) -> bytes:
Expand Down
Loading