diff --git a/plugins/tenet/integration/api/ida_api.py b/plugins/tenet/integration/api/ida_api.py index 2a62187..ba207f4 100644 --- a/plugins/tenet/integration/api/ida_api.py +++ b/plugins/tenet/integration/api/ida_api.py @@ -22,6 +22,7 @@ import ida_diskio import ida_kernwin import ida_segment +import ida_ida from .api import DisassemblerCoreAPI, DisassemblerContextAPI from ...util.qt import * @@ -198,9 +199,9 @@ def get_processor_type(self): pass def is_64bit(self): - inf = ida_idaapi.get_inf_structure() - #target_filetype = inf.filetype - return inf.is_64bit() + # The get_inf_structure() function is deprecated in newer IDA versions. + # The 'inf' object is now directly accessible via ida_ida. + return ida_ida.inf_is_64bit() def is_call_insn(self, address): insn = ida_ua.insn_t() diff --git a/plugins/tenet/ui/hex_view.py b/plugins/tenet/ui/hex_view.py index 89bfb75..a1ab4dc 100644 --- a/plugins/tenet/ui/hex_view.py +++ b/plugins/tenet/ui/hex_view.py @@ -197,7 +197,7 @@ def _refresh_painting_metrics(self): self._width_aux = (self.model.num_bytes_per_line * self._char_width) + self._char_width * 2 # enforce a minimum view width, to ensure all text stays visible - self.setMinimumWidth(self._pos_aux + self._width_aux) + self.setMinimumWidth(int(self._pos_aux + self._width_aux)) def full_size(self): """ @@ -206,7 +206,7 @@ def full_size(self): if not self.model.data: return QtCore.QSize(0, 0) - width = self._pos_aux + (self.model.num_bytes_per_line * self._char_width) + width = int(self._pos_aux + (self.model.num_bytes_per_line * self._char_width)) height = len(self.model.data) // self.model.num_bytes_per_line if len(self.model.data) % self.model.num_bytes_per_line: height += 1 @@ -694,17 +694,17 @@ def paintEvent(self, event): painter.fillRect(event.rect(), self._palette.hex_data_bg) # paint address area background - address_area_rect = QtCore.QRect(0, event.rect().top(), self._width_addr, self.height()) + address_area_rect = QtCore.QRect(0, event.rect().top(), int(self._width_addr), self.height()) painter.fillRect(address_area_rect, self._palette.hex_address_bg) # paint line between address area and hex area painter.setPen(self._palette.hex_separator) - painter.drawLine(self._width_addr, event.rect().top(), self._width_addr, self.height()) + painter.drawLine(int(self._width_addr), event.rect().top(), int(self._width_addr), self.height()) # paint line between hex area and auxillary area line_pos = self._pos_aux painter.setPen(self._palette.hex_separator) - painter.drawLine(line_pos, event.rect().top(), line_pos, self.height()) + painter.drawLine(int(line_pos), event.rect().top(), int(line_pos), self.height()) for line_idx in range(0, self.num_lines_visible): self._paint_line(painter, line_idx) @@ -735,7 +735,7 @@ def _paint_line(self, painter, line_idx): pack_len = self.model.pointer_size address_fmt = '%016X' if pack_len == 8 else '%08X' address_text = address_fmt % address - painter.drawText(self._pos_addr, y, address_text) + painter.drawText(int(self._pos_addr), y, address_text) self._default_color = self._palette.hex_text_fg if address < self.model.fade_address: @@ -775,7 +775,7 @@ def _paint_line(self, painter, line_idx): else: ch = chr(ch) - painter.drawText(x_pos_aux, y, ch) + painter.drawText(int(x_pos_aux), y, ch) x_pos_aux += self._char_width def _paint_hex_item(self, painter, byte_idx, stop_idx, x, y): @@ -960,7 +960,7 @@ def _paint_text(self, painter, byte_idx, padding, x, y): painter.setPen(QtCore.Qt.NoPen) painter.setBrush(bg_color) - painter.drawRect(x_bg, y_bg, width, height) + painter.drawRect(int(x_bg), int(y_bg), int(width), int(height)) painter.setPen(fg_color) @@ -968,7 +968,7 @@ def _paint_text(self, painter, byte_idx, padding, x, y): # paint text # - painter.drawText(x, y, text) + painter.drawText(int(x), y, text) def _paint_magic(self, painter, byte_idx, stop_idx, x, y): """ @@ -1006,7 +1006,7 @@ def _paint_magic(self, painter, byte_idx, stop_idx, x, y): # draw the pointer pointer_str = ("0x%08X " % value).rjust(num_chars) - painter.drawText(x, y, pointer_str) + painter.drawText(int(x), y, pointer_str) x += num_chars * self._char_width return (byte_idx + self.model.pointer_size, x, y) diff --git a/plugins/tenet/ui/reg_view.py b/plugins/tenet/ui/reg_view.py index 6f5ce2d..42c4732 100644 --- a/plugins/tenet/ui/reg_view.py +++ b/plugins/tenet/ui/reg_view.py @@ -109,7 +109,7 @@ def __init__(self, controller, model, parent=None): self.setFocusPolicy(QtCore.Qt.StrongFocus) self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) - self.setMinimumWidth(self._reg_pos[0] + self._default_width) + self.setMinimumWidth(int(self._reg_pos[0] + self._default_width)) self.setMouseTracking(True) self._init_ctx_menu() @@ -118,8 +118,8 @@ def __init__(self, controller, model, parent=None): self.model.registers_changed(self.refresh) def sizeHint(self): - width = self._default_width - height = (len(self._reg_fields) + 2) * self._char_height # +2 for line break before IP, and after IP + width = int(self._default_width) + height = int((len(self._reg_fields) + 2) * self._char_height) # +2 for line break before IP, and after IP return QtCore.QSize(width, height) def _init_ctx_menu(self): @@ -152,7 +152,7 @@ def _init_reg_positions(self): fm = QtGui.QFontMetricsF(self.font()) name_size = fm.boundingRect('X'*common_count).size() value_size = fm.boundingRect('0' * (self.model.arch.POINTER_SIZE * 2)).size() - arrow_size = (int(value_size.height() * 0.70) & 0xFE) + 1 + arrow_size = (int(value_size.height() * 0.70) | 1) # pre-compute the position of each register in the window for reg_name in regs: @@ -162,24 +162,24 @@ def _init_reg_positions(self): if reg_name == self.model.arch.IP: y += self._char_height - name_rect = QtCore.QRect(0, 0, name_size.width(), name_size.height()) - name_rect.moveBottomLeft(QtCore.QPoint(name_x, y)) + name_rect = QtCore.QRect(0, 0, int(name_size.width()), int(name_size.height())) + name_rect.moveBottomLeft(QtCore.QPoint(int(name_x), int(y))) - prev_rect = QtCore.QRect(0, 0, arrow_size, arrow_size) - next_rect = QtCore.QRect(0, 0, arrow_size, arrow_size) + prev_rect = QtCore.QRect(0, 0, int(arrow_size), int(arrow_size)) + next_rect = QtCore.QRect(0, 0, int(arrow_size), int(arrow_size)) arrow_rects = [prev_rect, next_rect] prev_x = name_x + name_size.width() + self._char_width prev_rect.moveCenter(name_rect.center()) - prev_rect.moveLeft(prev_x) + prev_rect.moveLeft(int(prev_x)) value_x = prev_x + prev_rect.width() + self._char_width - value_rect = QtCore.QRect(0, 0, value_size.width(), value_size.height()) - value_rect.moveBottomLeft(QtCore.QPoint(value_x, y)) + value_rect = QtCore.QRect(0, 0, int(value_size.width()), int(value_size.height())) + value_rect.moveBottomLeft(QtCore.QPoint(int(value_x), int(y))) next_x = value_x + value_size.width() + self._char_width next_rect.moveCenter(name_rect.center()) - next_rect.moveLeft(next_x) + next_rect.moveLeft(int(next_x)) # save the register shapes self._reg_fields[reg_name] = RegisterField(reg_name, name_rect, value_rect, arrow_rects) @@ -274,8 +274,8 @@ def full_size(self): if not self.model.registers: return QtCore.QSize(0, 0) - width = self._reg_pos[0] + self._default_width - height = len(self.model.registers) * self._char_height + width = int(self._reg_pos[0] + self._default_width) + height = int(len(self.model.registers) * self._char_height) return QtCore.QSize(width, height) diff --git a/plugins/tenet/ui/trace_view.py b/plugins/tenet/ui/trace_view.py index 7f3644f..ee05b02 100644 --- a/plugins/tenet/ui/trace_view.py +++ b/plugins/tenet/ui/trace_view.py @@ -479,7 +479,7 @@ def _idx2pos(self, idx): #assert self._cell_spacing % 2 == 0 # compute the y position of the 'first' cell - y += self._cell_spacing / 2 # pad out from top + y += self._cell_spacing // 2 # pad out from top y += self._cell_border # top border of cell # compute the y position of any given cell after the first @@ -953,7 +953,7 @@ def _draw_code_cells(self, painter): painter.setBrush(self.pctx.palette.trace_unmapped) y = self._idx2pos(idx) - painter.drawRect(x, y, w, h) + painter.drawRect(int(x), int(y), int(w), int(h)) def _draw_highlights(self): """ @@ -1008,7 +1008,7 @@ def _draw_highlights_cells(self, painter): y = self._idx2pos(idx) + self._cell_border # draw cell body - painter.drawRect(viz_x, y, viz_w, h) + painter.drawRect(int(viz_x), int(y), int(viz_w), int(h)) def _draw_highlights_trace(self, painter): """ @@ -1090,13 +1090,13 @@ def _draw_cursor(self): self._painter_cursor.setBrush(self.pctx.palette.trace_cursor_highlight) if draw_reader_cursor: - self._painter_cursor.drawRect(viz_x, cell_y, viz_w, cell_body_height) + self._painter_cursor.drawRect(int(viz_x), int(cell_y), int(viz_w), int(cell_body_height)) # cursor hover highlighting an event if self._hovered_idx != INVALID_IDX: hovered_y = self._idx2pos(self._hovered_idx) hovered_cell_y = hovered_y + self._cell_border - self._painter_cursor.drawRect(viz_x, hovered_cell_y, viz_w, cell_body_height) + self._painter_cursor.drawRect(int(viz_x), int(hovered_cell_y), int(viz_w), int(cell_body_height)) # draw the user cursor in dense/landscape mode else: @@ -1172,7 +1172,7 @@ def _draw_selection(self): h = end_y - start_y # draw the screen door / selection rect - self._painter_selection.drawRect(x, y, w, h) + self._painter_selection.drawRect(int(x), int(y), int(w), int(h)) def _draw_border(self): """ diff --git a/plugins/tenet/util/qt/waitbox.py b/plugins/tenet/util/qt/waitbox.py index c0b05f9..586dcc8 100644 --- a/plugins/tenet/util/qt/waitbox.py +++ b/plugins/tenet/util/qt/waitbox.py @@ -86,16 +86,16 @@ def _ui_layout(self): self._abort_button.clicked.connect(self._abort) v_layout.addWidget(self._abort_button) - v_layout.setSpacing(self._dpi_scale*3) + v_layout.setSpacing(int(self._dpi_scale*3)) v_layout.setContentsMargins( - self._dpi_scale*5, - self._dpi_scale, - self._dpi_scale*5, - self._dpi_scale + int(self._dpi_scale*5), + int(self._dpi_scale), + int(self._dpi_scale*5), + int(self._dpi_scale) ) # scale widget dimensions based on DPI - height = self._dpi_scale * 15 + height = int(self._dpi_scale * 15) self.setMinimumHeight(height) # compute the dialog layout diff --git a/plugins_sogen-support/tenet/__init__.py b/plugins_sogen-support/tenet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins_sogen-support/tenet/__pycache__/__init__.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..ee7b145 Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/__init__.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/breakpoints.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/breakpoints.cpython-311.pyc new file mode 100644 index 0000000..242072d Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/breakpoints.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/context.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000..5e3de3b Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/context.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/core.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000..18bb2e8 Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/core.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/hex.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/hex.cpython-311.pyc new file mode 100644 index 0000000..aaabc44 Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/hex.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/memory.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/memory.cpython-311.pyc new file mode 100644 index 0000000..35e1052 Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/memory.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/registers.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/registers.cpython-311.pyc new file mode 100644 index 0000000..133e563 Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/registers.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/stack.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/stack.cpython-311.pyc new file mode 100644 index 0000000..f5d9867 Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/stack.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/__pycache__/types.cpython-311.pyc b/plugins_sogen-support/tenet/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000..0578fe3 Binary files /dev/null and b/plugins_sogen-support/tenet/__pycache__/types.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/breakpoints.py b/plugins_sogen-support/tenet/breakpoints.py new file mode 100644 index 0000000..37ed00b --- /dev/null +++ b/plugins_sogen-support/tenet/breakpoints.py @@ -0,0 +1,195 @@ +import itertools + +from tenet.ui import * +from tenet.types import BreakpointType, BreakpointEvent, TraceBreakpoint +from tenet.util.misc import register_callback, notify_callback +from tenet.integration.api import DockableWindow +from tenet.integration.api import disassembler + +#------------------------------------------------------------------------------ +# breakpoints.py -- Breakpoint Controller +#------------------------------------------------------------------------------ +# +# The purpose of this file is to house the 'headless' components of the +# breakpoints window and its underlying functionality. This is split into +# a model and controller component, of a typical 'MVC' design pattern. +# +# v0.1 NOTE/TODO: err, a dedicated bp window was planned but did not quite +# make the cut for the initial release of this plugin. For that reason, +# some of this logic may be half-baked pending further work. +# +# v0.2 NOTE/TODO: Currently, the breakpoint controller/Tenet artificially +# limits usage to one execution breakpoint and one memory breakpoint at +# a time. I'll probably raise this 'limit' when a proper gui is made +# for managing and differentiating between breakpoints... +# + +class BreakpointController(object): + """ + The Breakpoint Controller (Logic) + """ + + def __init__(self, pctx): + self.pctx = pctx + self.model = BreakpointModel() + + # UI components + if QT_AVAILABLE: + self.view = BreakpointView(self, self.model) + self.dockable = DockableWindow("Trace Breakpoints", self.view) + else: + self.view = None + self.dockable = None + + # events + self._ignore_signals = False + self.pctx.core.ui_breakpoint_changed(self._ui_breakpoint_changed) + + def reset(self): + """ + Reset the breakpoint controller. + """ + self.model.reset() + + def add_breakpoint(self, address, access_type, length=1): + """ + Add a breakpoint of the given access type. + """ + if access_type == BreakpointType.EXEC: + self.add_execution_breakpoint(address, length) + elif access_type == BreakpointType.READ: + self.add_read_breakpoint(address, length) + elif access_type == BreakpointType.WRITE: + self.add_write_breakpoint(address, length) + elif access_type == BreakpointType.ACCESS: + self.add_access_breakpoint(address, length) + else: + raise ValueError("UNKNOWN ACCESS TYPE", access_type) + + def add_execution_breakpoint(self, address): + """ + Add an execution breakpoint for the given address. + """ + self.model.bp_exec[address] = TraceBreakpoint(address, BreakpointType.EXEC) + self.model._notify_breakpoints_changed() + + def add_read_breakpoint(self, address, length=1): + """ + Add a memory read breakpoint for the given address. + """ + self.model.bp_read[address] = TraceBreakpoint(address, BreakpointType.READ, length) + self.model._notify_breakpoints_changed() + + def add_write_breakpoint(self, address, length=1): + """ + Add a memory write breakpoint for the given address. + """ + self.model.bp_write[address] = TraceBreakpoint(address, BreakpointType.WRITE, length) + self.model._notify_breakpoints_changed() + + def add_access_breakpoint(self, address, length=1): + """ + Add a memory access breakpoint for the given address. + """ + self.model.bp_access[address] = TraceBreakpoint(address, BreakpointType.ACCESS, length) + self.model._notify_breakpoints_changed() + + def clear_breakpoints(self): + """ + Clear all breakpoints. + """ + self.model.bp_exec = {} + self.model.bp_read = {} + self.model.bp_write = {} + self.model.bp_access = {} + self.model._notify_breakpoints_changed() + + def clear_execution_breakpoints(self): + """ + Clear all execution breakpoints. + """ + self.model.bp_exec = {} + self.model._notify_breakpoints_changed() + + def clear_memory_breakpoints(self): + """ + Clear all memory breakpoints. + """ + self.model.bp_read = {} + self.model.bp_write = {} + self.model.bp_access = {} + self.model._notify_breakpoints_changed() + + def _ui_breakpoint_changed(self, address, event_type): + """ + Handle a breakpoint change event from the UI. + """ + if self._ignore_signals: + return + + self._delete_disassembler_breakpoints() + self.model.bp_exec = {} + + if event_type in [BreakpointEvent.ADDED, BreakpointEvent.ENABLED]: + self.add_execution_breakpoint(address) + + self.model._notify_breakpoints_changed() + + def _delete_disassembler_breakpoints(self): + """ + Remove all execution breakpoints from the disassembler UI. + """ + dctx = disassembler[self.pctx] + + self._ignore_signals = True + for address in self.model.bp_exec: + dctx.delete_breakpoint(address) + self._ignore_signals = False + +class BreakpointModel(object): + """ + The Breakpoint Model (Data) + """ + + def __init__(self): + self.reset() + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + self._breakpoints_changed_callbacks = [] + + def reset(self): + self.bp_exec = {} + self.bp_read = {} + self.bp_write = {} + self.bp_access = {} + + @property + def memory_breakpoints(self): + """ + Return an iterable list of all memory breakpoints. + """ + bps = itertools.chain( + self.bp_read.values(), + self.bp_write.values(), + self.bp_access.values() + ) + return bps + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + def breakpoints_changed(self, callback): + """ + Subscribe a callback for a breakpoint changed event. + """ + register_callback(self._breakpoints_changed_callbacks, callback) + + def _notify_breakpoints_changed(self): + """ + Notify listeners of a breakpoint changed event. + """ + notify_callback(self._breakpoints_changed_callbacks) diff --git a/plugins_sogen-support/tenet/context.py b/plugins_sogen-support/tenet/context.py new file mode 100644 index 0000000..340e5bd --- /dev/null +++ b/plugins_sogen-support/tenet/context.py @@ -0,0 +1,429 @@ +import os +import logging +import traceback + +from tenet.util.qt import * +from tenet.util.log import pmsg +from tenet.util.misc import is_plugin_dev + +from tenet.stack import StackController +from tenet.memory import MemoryController +from tenet.registers import RegisterController +from tenet.breakpoints import BreakpointController +from tenet.ui.trace_view import TraceDock + +from tenet.types import BreakpointType +from tenet.trace.arch import ArchAMD64, ArchX86 +from tenet.trace.reader import TraceReader +from tenet.integration.api import disassembler, DisassemblerContextAPI + +logger = logging.getLogger("Tenet.Context") + +#------------------------------------------------------------------------------ +# context.py -- Plugin Database Context +#------------------------------------------------------------------------------ +# +# The purpose of this file is to house and manage the plugin's +# disassembler database (eg, IDB/BNDB) specific runtime state. +# +# At a high level, a unique 'instance' of the plugin runtime & subsystems +# are initialized for each opened database in supported disassemblers. The +# plugin context object acts a bit like the database specific plugin core. +# +# For example, it is possible for multiple databases to be open at once +# in the Binary Ninja disassembler. Each opened database will have a +# unique plugin context object created and used to manage state, UI, +# threads/subsystems, and loaded plugin data for that database. +# +# In IDA, this is less important as you can only have one database open +# at any given time (... at least at the time of writing) but that does +# not change how this context system works under the hood. +# + +class TenetContext(object): + """ + A per-database encapsulation of the plugin components / state. + """ + + def __init__(self, core, db): + disassembler[self] = DisassemblerContextAPI(db) + self.core = core + self.db = db + + # select a trace arch based on the binary the disassmbler has loaded + if disassembler[self].is_64bit(): + self.arch = ArchAMD64() + else: + self.arch = ArchX86() + + # this will hold the trace reader when a trace has been loaded + self.reader = None + + # plugin widgets / components + self.breakpoints = BreakpointController(self) + self.trace = TraceDock(self) # TODO: port this one to MVC pattern + self.stack = StackController(self) + self.memory = MemoryController(self) + self.registers = RegisterController(self) + + # the directory to start the 'load trace file' dialog in + self._last_directory = None + + # whether the plugin subsystems have been created / started + self._started = False + + # NOTE/DEV: automatically open a test trace file when dev/testing + if is_plugin_dev(): + self._auto_launch() + + def _auto_launch(self): + """ + Automatically load a static trace file when the database has been opened. + + NOTE/DEV: this is just to make it easier to test / develop / debug the + plugin when developing it and should not be called under normal use. + """ + + def test_load(): + import ida_loader + trace_filepath = ida_loader.get_plugin_options("Tenet") + focus_window() + self.load_trace(trace_filepath) + self.show_ui() + + def dev_launch(): + self._timer = QtCore.QTimer() + self._timer.singleShot(500, test_load) # delay to let things settle + + self.core._ui_hooks.ready_to_run = dev_launch + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def palette(self): + return self.core.palette + + #------------------------------------------------------------------------- + # Setup / Teardown + #------------------------------------------------------------------------- + + def start(self): + """ + One-time initialization of the plugin subsystems. + + This will only be called when it is clear the user is attempting + to use the plugin or its functionality (eg, they click load trace). + """ + if self._started: + return + + self.palette.warmup() + self._started = True + + def terminate(self): + """ + Spin down any plugin subsystems as the context is being deleted. + + This will be called when the database or disassembler is closing. + """ + self.close_trace() + + #------------------------------------------------------------------------- + # Public API + #------------------------------------------------------------------------- + + def trace_loaded(self): + """ + Return True if a trace is loaded / active in this plugin context. + """ + return bool(self.reader) + + def load_trace(self, filepath): + """ + Load a trace from the given filepath. + + If there is a trace already loaded / in-use prior to calling this + function, it will simply be replaced by the new trace. + """ + + # + # create the trace reader. this will load the given trace file from + # disk and wrap it with a number of useful APIs for navigating the + # trace and querying information (memory, registers) from it at + # chosen states of execution + # + + self.reader = TraceReader(filepath, self.arch, disassembler[self]) + pmsg(f"Loaded trace {self.reader.trace.filepath}") + pmsg(f"- {self.reader.trace.length:,} instructions...") + + if self.reader.analysis.slide != None: + pmsg(f"- {self.reader.analysis.slide:08X} ASLR slide...") + else: + disassembler.warning("Failed to automatically detect ASLR base!\n\nSee console for more info...") + pmsg(" +------------------------------------------------------") + pmsg(" |- ERROR: Failed to detect ASLR base for this trace.") + pmsg(" | --------------------------------------- ") + pmsg(" +-+ You can 'try' rebasing the database to the correct ASLR base") + pmsg(" | if you know it, and reload the trace. Otherwise, it is possible") + pmsg(" | your trace is just... very small and Tenet was not confident") + pmsg(" | predicting an ASLR slide.") + + # + # we only hook directly into the disassembler / UI / subsytems once + # a trace is loaded. this ensures that our python handlers don't + # introduce overhead on misc disassembler callbacks when the plugin + # isn't even being used in the reversing session. + # + + self.core.hook() + + # + # attach the trace engine to the various plugin UI controllers, giving + # them the necessary access to drive the underlying trace reader + # + + self.breakpoints.reset() + self.trace.attach_reader(self.reader) + self.stack.attach_reader(self.reader) + self.memory.attach_reader(self.reader) + self.registers.attach_reader(self.reader) + + # + # connect any high level signals from the new trace reader + # + + self.reader.idx_changed(self._idx_changed) + + def close_trace(self): + """ + Close the current trace if one is active. + """ + if not self.reader: + return + + # + # unhook the disassembler, as there will be no active / loaded trace + # after this routine completes + # + + self.core.unhook() + + # + # close UI elements and reset their model / controllers + # + + self.trace.hide() + self.trace.detach_reader() + self.stack.hide() + self.stack.detach_reader() + self.memory.hide() + self.memory.detach_reader() + self.registers.hide() + self.registers.detach_reader() + + # misc / final cleanup + self.breakpoints.reset() + #self.reader.close() + + self.reader = None + + def show_ui(self): + """ + Integrate and arrange the plugin widgets into the disassembler UI. + + TODO: ehh, there really shouldn't be any disassembler-specific stuff + outside of the disassembler integration files. it doesn't really + matter much right now but this should be moved in the future. + """ + import ida_kernwin + self.registers.show(position=ida_kernwin.DP_RIGHT) + + #self.breakpoints.dockable.set_dock_position("CPU Registers", ida_kernwin.DP_BOTTOM) + #self.breakpoints.dockable.show() + + #ida_kernwin.activate_widget(ida_kernwin.find_widget("Output window"), True) + #ida_kernwin.set_dock_pos("Output window", None, ida_kernwin.DP_BOTTOM) + #ida_kernwin.set_dock_pos("IPython Console", "Output", ida_kernwin.DP_INSIDE) + + #self.memory.dockable.set_dock_position("Output window", ida_kernwin.DP_TAB | ida_kernwin.DP_BEFORE) + self.memory.show("Output window", ida_kernwin.DP_TAB | ida_kernwin.DP_BEFORE) + + #self.stack.dockable.set_dock_position("Memory View", ida_kernwin.DP_RIGHT) + self.stack.show("Memory View", ida_kernwin.DP_RIGHT) + + mw = get_qmainwindow() + mw.addToolBar(QtCore.Qt.RightToolBarArea, self.trace) + self.trace.show() + + # trigger update check + self.core.check_for_update() + + #------------------------------------------------------------------------- + # Integrated UI Event Handlers + #------------------------------------------------------------------------- + + def interactive_load_trace(self, reloading=False): + """ + Handle UI actions for loading a trace file. + """ + + # prompt the user with a file dialog to select a trace of interest + filenames = self._select_trace_file() + if not filenames: + return + + # TODO: ehh, only support loading one trace at a time right now + assert len(filenames) == 1, "Please select only one trace file to load" + disassembler.show_wait_box("Loading trace from disk...") + filepath = filenames[0] + + # attempt to load the user selected trace + try: + self.load_trace(filepath) + except: + pmsg("Failed to load trace...") + pmsg(traceback.format_exc()) + disassembler.hide_wait_box() + return + disassembler.hide_wait_box() + + # + # if we are 're-loading', we are loading over an existing trace, so + # there should already be plugin UI elements visible and active. + # + # do not attempt to show / re-position the UI elements as they may + # have been moved by the user from their default positions into + # locations that they prefer + # + + if reloading: + return + + # show the plugin UI elements, and dock its windows as appropriate + self.show_ui() + + def interactive_next_execution(self): + """ + Handle UI actions for seeking to the next execution of the selected address. + """ + address = disassembler[self].get_current_address() + rebased_address = self.reader.analysis.rebase_pointer(address) + result = self.reader.seek_to_next(rebased_address, BreakpointType.EXEC) + + # TODO: blink screen? make failure more visible... + if not result: + pmsg(f"Go to 0x{address:08x} failed, no future executions of address") + + def interactive_prev_execution(self): + """ + Handle UI actions for seeking to the previous execution of the selected address. + """ + address = disassembler[self].get_current_address() + rebased_address = self.reader.analysis.rebase_pointer(address) + result = self.reader.seek_to_prev(rebased_address, BreakpointType.EXEC) + + # TODO: blink screen? make failure more visible... + if not result: + pmsg(f"Go to 0x{address:08x} failed, no previous executions of address") + + def interactive_first_execution(self): + """ + Handle UI actions for seeking to the first execution of the selected address. + """ + address = disassembler[self].get_current_address() + rebased_address = self.reader.analysis.rebase_pointer(address) + result = self.reader.seek_to_first(rebased_address, BreakpointType.EXEC) + + # TODO: blink screen? make failure more visible... + if not result: + pmsg(f"Go to 0x{address:08x} failed, no executions of address") + + def interactive_final_execution(self): + """ + Handle UI actions for seeking to the final execution of the selected address. + """ + address = disassembler[self].get_current_address() + rebased_address = self.reader.analysis.rebase_pointer(address) + result = self.reader.seek_to_final(rebased_address, BreakpointType.EXEC) + + # TODO: blink screen? make failure more visible... + if not result: + pmsg(f"Go to 0x{address:08x} failed, no executions of address") + + def _idx_changed(self, idx): + """ + Handle a trace reader event indicating that the current IDX has changed. + + This will make the disassembler track with the PC/IP of the trace reader. + """ + dctx = disassembler[self] + + # + # get a 'rebased' version of the current instruction pointer, which + # should map to the disassembler / open database if it is a code + # address that is known + # + + bin_address = self.reader.rebased_ip + + # + # if the code address is in a library / other unknown area that + # cannot be renedered by the disassembler, then resolve the last + # known trace 'address' within the database + # + + if not dctx.is_mapped(bin_address): + last_good_idx = self.reader.analysis.get_prev_mapped_idx(idx) + if last_good_idx == -1: + return # navigation is just not gonna happen... + + # fetch the last instruction pointer to fall within the trace + last_good_trace_address = self.reader.get_ip(last_good_idx) + + # convert the trace-based instruction pointer to one that maps to the disassembler + bin_address = self.reader.analysis.rebase_pointer(last_good_trace_address) + + # navigate the disassembler to a 'suitable' address based on the trace idx + dctx.navigate(bin_address) + disassembler.refresh_views() + + def _select_trace_file(self): + """ + Prompt a file selection dialog, returning file selections. + + This will save & reuses the last known directory for subsequent calls. + """ + + if not self._last_directory: + self._last_directory = disassembler[self].get_database_directory() + + # create & configure a Qt File Dialog for immediate use + file_dialog = QtWidgets.QFileDialog( + None, + 'Open trace file', + self._last_directory, + 'All Files (*.*)' + ) + file_dialog.setFileMode(QtWidgets.QFileDialog.ExistingFiles) + + # prompt the user with the file dialog, and await filename(s) + filenames, _ = file_dialog.getOpenFileNames() + + # + # remember the last directory we were in (parsed from a selected file) + # for the next time the user comes to load trace files + # + + if filenames: + self._last_directory = os.path.dirname(filenames[0]) + os.sep + + # log the captured (selected) filenames from the dialog + logger.debug("Captured filenames from file dialog:") + for name in filenames: + logger.debug(" - %s" % name) + + # return the captured filenames + return filenames \ No newline at end of file diff --git a/plugins_sogen-support/tenet/core.py b/plugins_sogen-support/tenet/core.py new file mode 100644 index 0000000..861900d --- /dev/null +++ b/plugins_sogen-support/tenet/core.py @@ -0,0 +1,256 @@ +import abc +import logging + +from tenet.util.log import pmsg +from tenet.ui.palette import PluginPalette +from tenet.util.update import check_for_update +from tenet.integration.api import disassembler + +logger = logging.getLogger("Tenet.Core") + +#------------------------------------------------------------------------------ +# core.py -- Plugin Core +#------------------------------------------------------------------------------ +# +# The purpose of this file is to define a specification required by the +# plugin to integrate and load under a given disassembler. +# +# This is technically the 'lowest' level layer of the plugin, as it is +# loaded / unloaded directly by the disassembler. This means that there +# should be no database or user-specific data loaded into this layer. +# +# Supporting additional disassemblers will require one to subclass this +# abstract core as part of a disassembler-specific integration layer. +# + +class TenetCore(object): + """ + The disassembler-wide plugin core. + """ + __metaclass__ = abc.ABCMeta + + #-------------------------------------------------------------------------- + # Plugin Metadata + #-------------------------------------------------------------------------- + + PLUGIN_NAME = "Tenet" + PLUGIN_VERSION = "0.2.0" + PLUGIN_AUTHORS = "Markus Gaasedelen" + PLUGIN_DATE = "2021" + + #-------------------------------------------------------------------------- + # Initialization / Teardown + #-------------------------------------------------------------------------- + + def load(self): + """ + Load the plugin, and register universal UI actions with the disassembler. + """ + self.contexts = {} + self._update_checked = False + + # the plugin color palette + self.palette = PluginPalette() + self.palette.theme_changed(self.refresh_theme) + + # integrate plugin UI to disassembler + self._install_ui() + + # all done, mark the core as loaded + self.loaded = True + + # print plugin banner + pmsg(f"Loaded v{self.PLUGIN_VERSION} - (c) {self.PLUGIN_AUTHORS} - {self.PLUGIN_DATE}") + logger.info("Successfully loaded plugin") + + def unload(self): + """ + Unload the plugin, and remove any UI integrations. + """ + if not self.loaded: + return + + pmsg("Unloading %s..." % self.PLUGIN_NAME) + + # mark the core as 'unloaded' and teardown its components + self.loaded = False + + # remove UI integrations + self._uninstall_ui() + + # spin down any active contexts (stop threads, cleanup qt state, etc) + for pctx in self.contexts.values(): + pctx.terminate() + self.contexts = {} + + # all done + logger.info("-"*75) + logger.info("Plugin terminated") + + @abc.abstractmethod + def hook(self): + """ + Install disassmbler-specific hooks. + """ + pass + + @abc.abstractmethod + def unhook(self): + """ + Remove disassmbler-specific hooks. + """ + pass + + #-------------------------------------------------------------------------- + # Disassembler / Database Context Selector + #-------------------------------------------------------------------------- + + @abc.abstractmethod + def get_context(self, db, startup=True): + """ + Get the plugin context object for the given database / session. + """ + pass + + #-------------------------------------------------------------------------- + # UI Integration + #-------------------------------------------------------------------------- + + def _install_ui(self): + """ + Initialize & integrate all plugin UI elements. + """ + self._install_load_trace() + self._install_next_execution() + self._install_prev_execution() + self._install_first_execution() + self._install_final_execution() + + def _uninstall_ui(self): + """ + Cleanup & remove all plugin UI integrations. + """ + self._uninstall_load_trace() + self._uninstall_next_execution() + self._uninstall_prev_execution() + self._uninstall_first_execution() + self._uninstall_final_execution() + + @abc.abstractmethod + def _install_load_trace(self): + """ + Install the 'File->Load->Tenet trace file...' menu entry. + """ + pass + + @abc.abstractmethod + def _install_next_execution(self): + """ + Install the right click 'Go to next execution' menu entry. + """ + pass + + @abc.abstractmethod + def _install_prev_execution(self): + """ + Install the right click 'Go to previous execution' menu entry. + """ + pass + + @abc.abstractmethod + def _install_first_execution(self): + """ + Install the right click 'Go to first execution' menu entry. + """ + pass + + @abc.abstractmethod + def _install_final_execution(self): + """ + Install the right click 'Go to final execution' menu entry. + """ + pass + + @abc.abstractmethod + def _uninstall_load_trace(self): + """ + Remove the 'File->Load file->Tenet trace file...' menu entry. + """ + pass + + @abc.abstractmethod + def _uninstall_next_execution(self): + """ + Remove the right click 'Go to next execution' menu entry. + """ + pass + + @abc.abstractmethod + def _uninstall_prev_execution(self): + """ + Remove the right click 'Go to previous execution' menu entry. + """ + pass + + @abc.abstractmethod + def _uninstall_first_execution(self): + """ + Remove the right click 'Go to first execution' menu entry. + """ + pass + + @abc.abstractmethod + def _uninstall_final_execution(self): + """ + Remove the right click 'Go to final execution' menu entry. + """ + pass + + #-------------------------------------------------------------------------- + # UI Event Handlers + #-------------------------------------------------------------------------- + + def _interactive_load_trace(self, db): + pctx = self.get_context(db) + pctx.interactive_load_trace() + + def _interactive_first_execution(self, db): + pctx = self.get_context(db) + pctx.interactive_first_execution() + + def _interactive_final_execution(self, db): + pctx = self.get_context(db) + pctx.interactive_final_execution() + + def _interactive_next_execution(self, db): + pctx = self.get_context(db) + pctx.interactive_next_execution() + + def _interactive_prev_execution(self, db): + pctx = self.get_context(db) + pctx.interactive_prev_execution() + + #-------------------------------------------------------------------------- + # Core Actions + #-------------------------------------------------------------------------- + + def refresh_theme(self): + """ + Refresh UI facing elements to reflect the current theme. + """ + for pctx in self.contexts.values(): + pass # TODO + + def check_for_update(self): + """ + Check if there is an update available for the plugin. + """ + if self._update_checked: + return + + # wrap the callback (a popup) to ensure it gets called from the UI + callback = disassembler.execute_ui(disassembler.warning) + + # kick off the async update check + check_for_update(self.PLUGIN_VERSION, callback) + self._update_checked = True diff --git a/plugins_sogen-support/tenet/hex.py b/plugins_sogen-support/tenet/hex.py new file mode 100644 index 0000000..bd8b173 --- /dev/null +++ b/plugins_sogen-support/tenet/hex.py @@ -0,0 +1,305 @@ +from tenet.ui import * +from tenet.types import * +from tenet.util.qt.util import copy_to_clipboard +from tenet.integration.api import DockableWindow + +#------------------------------------------------------------------------------ +# hex.py -- Hex Dump Controller +#------------------------------------------------------------------------------ +# +# The purpose of this file is to house the 'headless' components of a +# basic hex dump window and its underlying functionality. This is split +# into a model and controller component, of a typical 'MVC' design pattern. +# +# This provides much of the core logic behind both the memory and stack +# views used by the plugin. +# + +class HexController(object): + """ + A generalized controller for Hex View based window. + """ + + def __init__(self, pctx): + self.pctx = pctx + self.model = HexModel(pctx) + self.reader = None + + # UI components + self.view = None + self.dockable = None + self._title = "" + + # signals + self._ignore_signals = False + pctx.breakpoints.model.breakpoints_changed(self._breakpoints_changed) + + def show(self, target=None, position=0): + """ + Make the window attached to this controller visible. + """ + + # if there is no Qt (eg, our UI framework...) then there is no UI + if not QT_AVAILABLE: + return + + # the UI has already been created, and is also visible. nothing to do + if (self.dockable and self.dockable.visible): + return + + # + # if the UI has not yet been created, or has been previously closed + # then we are free to create new UI elements to take the place of + # anything that once was + + self.view = HexView(self, self.model) + new_dockable = DockableWindow(self._title, self.view) + + # + # if there is a reference to a left over dockable window (e.g, from a + # previous close of this window type) steal its dock positon so we can + # hopefully take the same place as the old one + # + + if self.dockable: + new_dockable.copy_dock_position(self.dockable) + elif (target or position): + new_dockable.set_dock_position(target, position) + + # make the dockable/widget visible + self.dockable = new_dockable + self.dockable.show() + + def hide(self): + """ + Hide the window attached to this controller. + """ + + # if there is no view/dockable, then there's nothing to try and hide + if not(self.view and self.dockable): + return + + # hide the dockable, and drop references to the widgets + self.dockable.hide() + self.view = None + self.dockable = None + + def attach_reader(self, reader): + """ + Attach a trace reader to this controller. + """ + self.reader = reader + self.model.pointer_size = reader.arch.POINTER_SIZE + + # attach trace reader signals to this controller / window + reader.idx_changed(self._idx_changed) + + # + # directly call our event handler quick with the current idx since + # it's the first time we're seeing this. this ensures that our widget + # will accurately reflect the current state of the reader + # + + self._idx_changed(reader.idx) + + def detach_reader(self): + """ + Detach the trace reader from this controller. + """ + self.reader = None + self.model.reset() + + def navigate(self, address): + """ + Navigate the hex view to a given address. + """ + if address < 0: + address = 0 + + last_visible_address = address + self.model.data_size + if last_visible_address > 0xFFFFFFFFFFFFFFFF: + last_visible_address = 0xFFFFFFFFFFFFFFFF + + self.model.address = address + + #self.reset_selection(0) + self.refresh_memory() + + def set_data_size(self, num_bytes): + """ + Change the number of bytes to be held / displayed by the viewer. + """ + self.model.data_size = num_bytes + self.refresh_memory() + + def copy_selection(self, start_address, end_address): + """ + Copy the selected range of bytes to the system clipboard. + """ + assert end_address > start_address + if not self.reader: + return '' + + # fetch memory for the selected region + num_bytes = end_address - start_address + memory = self.reader.get_memory(start_address, num_bytes) + + # dump bytes to hex + output = [] + for i in range(num_bytes): + if memory.mask[i] == 0xFF: + output.append("%02X" % memory.data[i]) + else: + output.append("??") + + byte_string = ' '.join(output) + copy_to_clipboard(byte_string) + + return byte_string + + def pin_memory(self, address, access_type=BreakpointType.ACCESS, length=1): + """ + Pin a region of memory. + """ + self._ignore_signals = True + self.pctx.breakpoints.clear_memory_breakpoints() + self.pctx.breakpoints.add_breakpoint(address, access_type, length) + self._ignore_signals = False + + def refresh_memory(self): + """ + Refresh the visible memory. + """ + if not self.reader: + self.model.data = None + self.model.mask = None + return + + memory = self.reader.get_memory(self.model.address, self.model.data_size) + + self.model.data = memory.data + self.model.mask = memory.mask + self.model.delta = self.reader.delta + + if self.view: + self.view.refresh() + + def set_fade_threshold(self, address): + """ + Change the threshold address that the view will begin to 'fade' its contents. + + This is used to 'fade' the unallocated region of the stack, for example. + """ + self.model.fade_address = address + + #------------------------------------------------------------------------- + # Callbacks + #------------------------------------------------------------------------- + + def _idx_changed(self, idx): + """ + The trace reader position has been changed. + """ + self.refresh_memory() + + def _breakpoints_changed(self): + """ + Handle breakpoints changed event. + """ + if not self.view: + return + + if self._ignore_signals: + return + + self.view.refresh() + +class HexModel(object): + """ + A generalized model for Hex View based window. + """ + + def __init__(self, pctx): + self._pctx = pctx + + # how the hex (data) and auxillary text should be displayed + self._hex_format = HexType.BYTE + self._aux_format = AuxType.ASCII + + # view settings + self._num_bytes_per_line = 16 + + # initialize the remaining model parameters + self.reset() + + def reset(self): + """ + Reset the model to a clean state. + """ + + # the 'cached' data to be displayed by the hex view + self.data = None + self.mask = None + self.data_size = 0 + self.delta = None + + self.address = 0 + self.fade_address = 0 + + # pinned memory / breakpoint selections + self._pinned_selections = [] + + #---------------------------------------------------------------------- + # Properties + #---------------------------------------------------------------------- + + @property + def memory_breakpoints(self): + """ + Return the set of active memory breakpoints. + """ + return self._pctx.breakpoints.model.memory_breakpoints + + @property + def num_bytes_per_line(self): + """ + Return the number of bytes that should be displayed per line. + """ + return self._num_bytes_per_line + + @num_bytes_per_line.setter + def num_bytes_per_line(self, width): + """ + Set the number of bytes to be displayed per line. + """ + + if width < 1: + raise ValueError("Invalid bytes per line value (must be > 0)") + + if width % HEX_TYPE_WIDTH[self._hex_format]: + raise ValueError("Bytes per line must be a multiple of display format type") + + self._num_bytes_per_line = width + #self._refresh_view_settings() + + @property + def hex_format(self): + return self._hex_format + + @hex_format.setter + def hex_format(self, value): + if value == self._hex_format: + return + self._hex_format = value + #self.refresh() + + @property + def aux_format(self): + return self._aux_format + + @aux_format.setter + def aux_format(self, value): + if value == self._aux_format: + return + self._aux_format = value + #self.refresh() \ No newline at end of file diff --git a/plugins_sogen-support/tenet/integration/__pycache__/ida_integration.cpython-311.pyc b/plugins_sogen-support/tenet/integration/__pycache__/ida_integration.cpython-311.pyc new file mode 100644 index 0000000..0a62a2a Binary files /dev/null and b/plugins_sogen-support/tenet/integration/__pycache__/ida_integration.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/integration/__pycache__/ida_loader.cpython-311.pyc b/plugins_sogen-support/tenet/integration/__pycache__/ida_loader.cpython-311.pyc new file mode 100644 index 0000000..c5239af Binary files /dev/null and b/plugins_sogen-support/tenet/integration/__pycache__/ida_loader.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/integration/api/__init__.py b/plugins_sogen-support/tenet/integration/api/__init__.py new file mode 100644 index 0000000..392938e --- /dev/null +++ b/plugins_sogen-support/tenet/integration/api/__init__.py @@ -0,0 +1,40 @@ +#-------------------------------------------------------------------------- +# Disassembler API Selector +#-------------------------------------------------------------------------- +# +# this file will select and load the shimmed disassembler API for the +# appropriate (current) disassembler platform. +# +# see api.py for more details regarding this API shim layer +# + +disassembler = None + +#-------------------------------------------------------------------------- +# IDA API Shim +#-------------------------------------------------------------------------- + +if disassembler == None: + from .ida_api import IDACoreAPI, IDAContextAPI, DockableWindow + disassembler = IDACoreAPI() + DisassemblerContextAPI = IDAContextAPI + +##-------------------------------------------------------------------------- +## Binary Ninja API Shim +##-------------------------------------------------------------------------- +# +#if disassembler == None: +# try: +# from .binja_api import BinjaCoreAPI, BinjaContextAPI +# disassembler = BinjaCoreAPI() +# DisassemblerContextAPI = BinjaContextAPI +# except ImportError: +# pass + +#-------------------------------------------------------------------------- +# Unknown Disassembler +#-------------------------------------------------------------------------- + +if disassembler == None: + raise NotImplementedError("Unknown or unsupported disassembler!") + diff --git a/plugins_sogen-support/tenet/integration/api/__pycache__/__init__.cpython-311.pyc b/plugins_sogen-support/tenet/integration/api/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..55631d2 Binary files /dev/null and b/plugins_sogen-support/tenet/integration/api/__pycache__/__init__.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/integration/api/__pycache__/api.cpython-311.pyc b/plugins_sogen-support/tenet/integration/api/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000..6fb0a55 Binary files /dev/null and b/plugins_sogen-support/tenet/integration/api/__pycache__/api.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/integration/api/__pycache__/ida_api.cpython-311.pyc b/plugins_sogen-support/tenet/integration/api/__pycache__/ida_api.cpython-311.pyc new file mode 100644 index 0000000..cc5f0f5 Binary files /dev/null and b/plugins_sogen-support/tenet/integration/api/__pycache__/ida_api.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/integration/api/api.py b/plugins_sogen-support/tenet/integration/api/api.py new file mode 100644 index 0000000..37d7f8b --- /dev/null +++ b/plugins_sogen-support/tenet/integration/api/api.py @@ -0,0 +1,326 @@ +import abc +import logging + +from ...util.qt import * + +logger = logging.getLogger("Tenet.Integration.API") + +#------------------------------------------------------------------------------ +# Disassembler API +#------------------------------------------------------------------------------ +# +# the purpose of this file is to provide an abstraction layer for the more +# generic disassembler APIs required by the plugin codebase. we strive to +# use (or extend) this API for the bulk of our disassembler operations, +# making the plugin as disassembler-agnostic as possible. +# +# by subclassing the templated classes below, the plugin can support other +# disassembler plaforms relatively easily. at the moment, implementing these +# subclasses is ~50% of the work that is required to add support for this +# plugin to any given interactive disassembler. +# +# TODO: technically, a bunch of definitions are missing from this file +# that are present in the IDA integration implementation. these will +# need to be copied over to here to better define the disassembler API +# dependencies required by this plugin +# + +class DisassemblerCoreAPI(object): + """ + An abstract implementation of the core disassembler APIs. + """ + __metaclass__ = abc.ABCMeta + + # the name of the disassembler framework, eg 'IDA' or 'BINJA' + NAME = NotImplemented + + @abc.abstractmethod + def __init__(self): + self._ctxs = {} + + # required version fields + self._version_major = NotImplemented + self._version_minor = NotImplemented + self._version_patch = NotImplemented + + if not self.headless and QT_AVAILABLE: + self._waitbox = WaitBox("Please wait...") + else: + self._waitbox = None + + def __delitem__(self, key): + del self._ctxs[key] + + def __getitem__(self, key): + return self._ctxs[key] + + def __setitem__(self, key, value): + self._ctxs[key] = value + + #-------------------------------------------------------------------------- + # Properties + #-------------------------------------------------------------------------- + + def version_major(self): + """ + Return the major version number of the disassembler framework. + """ + assert self._version_major != NotImplemented + return self._version_major + + def version_minor(self): + """ + Return the minor version number of the disassembler framework. + """ + assert self._version_patch != NotImplemented + return self._version_patch + + def version_patch(self): + """ + Return the patch version number of the disassembler framework. + """ + assert self._version_patch != NotImplemented + return self._version_patch + + @abc.abstractproperty + def headless(self): + """ + Return a bool indicating if the disassembler is running without a GUI. + """ + pass + + #-------------------------------------------------------------------------- + # Synchronization Decorators + #-------------------------------------------------------------------------- + + @staticmethod + def execute_read(function): + """ + Thread-safe function decorator to READ from the disassembler database. + """ + raise NotImplementedError("execute_read() has not been implemented") + + @staticmethod + def execute_write(function): + """ + Thread-safe function decorator to WRITE to the disassembler database. + """ + raise NotImplementedError("execute_write() has not been implemented") + + @staticmethod + def execute_ui(function): + """ + Thread-safe function decorator to perform UI disassembler actions. + + This function is generally used for executing UI (Qt) events from + a background thread. as such, your implementation is expected to + transfer execution to the main application thread where it is safe to + perform Qt actions. + """ + raise NotImplementedError("execute_ui() has not been implemented") + + #-------------------------------------------------------------------------- + # Disassembler Universal APIs + #-------------------------------------------------------------------------- + + @abc.abstractmethod + def get_disassembler_user_directory(self): + """ + Return the 'user' directory for the disassembler. + """ + pass + + @abc.abstractmethod + def get_disassembly_background_color(self): + """ + Return the background color of the disassembly text view. + """ + pass + + @abc.abstractmethod + def is_msg_inited(self): + """ + Return a bool if the disassembler output window is initialized. + """ + pass + + def warning(self, text): + """ + Display a warning dialog box with the given text. + """ + msgbox = QtWidgets.QMessageBox() + before = msgbox.sizeHint().width() + msgbox.setIcon(QtWidgets.QMessageBox.Critical) + after = msgbox.sizeHint().width() + icon_width = after - before + + msgbox.setWindowTitle("Tenet Warning") + msgbox.setText(text) + + font = msgbox.font() + fm = QtGui.QFontMetricsF(font) + text_width = fm.size(0, text).width() + + # don't ask... + spacer = QtWidgets.QSpacerItem(int(text_width*1.1 + icon_width), 0, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) + layout = msgbox.layout() + layout.addItem(spacer, layout.rowCount(), 0, 1, layout.columnCount()) + msgbox.setLayout(layout) + + # show the dialog + msgbox.exec_() + + @abc.abstractmethod + def message(self, function_address, new_name): + """ + Print a message to the disassembler console. + """ + pass + + #-------------------------------------------------------------------------- + # UI APIs + #-------------------------------------------------------------------------- + + @abc.abstractmethod + def create_dockable(self, dockable_name, widget): + """ + Creates a dockable widget. + """ + pass + + #------------------------------------------------------------------------------ + # WaitBox API + #------------------------------------------------------------------------------ + + def show_wait_box(self, text, modal=True): + """ + Show the disassembler universal WaitBox. + """ + assert QT_AVAILABLE, "This function can only be used in a Qt runtime" + self._waitbox.set_text(text) + self._waitbox.show(modal) + + def hide_wait_box(self): + """ + Hide the disassembler universal WaitBox. + """ + assert QT_AVAILABLE, "This function can only be used in a Qt runtime" + self._waitbox.hide() + + def replace_wait_box(self, text): + """ + Replace the text in the disassembler universal WaitBox. + """ + assert QT_AVAILABLE, "This function can only be used in a Qt runtime" + self._waitbox.set_text(text) + +#------------------------------------------------------------------------------ +# Disassembler Contextual API +#------------------------------------------------------------------------------ + +class DisassemblerContextAPI(object): + """ + An abstract implementation of database/contextual disassembler APIs. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def __init__(self, dctx): + self.dctx = dctx + + #-------------------------------------------------------------------------- + # Properties + #-------------------------------------------------------------------------- + + @abc.abstractproperty + def busy(self): + """ + Return a bool indicating if the disassembler is busy / processing. + """ + pass + + #-------------------------------------------------------------------------- + # API Shims + #-------------------------------------------------------------------------- + + def is_64bit(self): + """ + Return True if the loaded processor module is 64bit. + """ + pass + + @abc.abstractmethod + def get_current_address(self): + """ + Return the current cursor address in the open database. + """ + pass + + @abc.abstractmethod + def get_database_directory(self): + """ + Return the directory for the open database. + """ + pass + + @abc.abstractmethod + def get_function_addresses(self): + """ + Return all defined function addresses in the open database. + """ + pass + + @abc.abstractmethod + def get_function_name_at(self, address): + """ + Return the name of the function at the given address. + + This is generally the user-facing/demangled name seen throughout the + disassembler and is probably what you want to use for almost everything. + """ + pass + + @abc.abstractmethod + def get_function_raw_name_at(self, address): + """ + Return the raw (eg, unmangled) name of the function at the given address. + + On the backend, most disassemblers store what is called the 'true' or + 'raw' (eg, unmangled) function name. + """ + pass + + @abc.abstractmethod + def get_imagebase(self): + """ + Return the base address of the open database. + """ + pass + + @abc.abstractmethod + def get_root_filename(self): + """ + Return the root executable (file) name used to generate the database. + """ + pass + + @abc.abstractmethod + def navigate(self, address, function_address=None): + """ + Jump the disassembler UI to the given address. + """ + pass + + @abc.abstractmethod + def navigate_to_function(self, function_address, address): + """ + Jump the disassembler UI to the given address, within a function. + """ + pass + + @abc.abstractmethod + def set_function_name_at(self, function_address, new_name): + """ + Set the function name at given address. + """ + pass \ No newline at end of file diff --git a/plugins_sogen-support/tenet/integration/api/ida_api.py b/plugins_sogen-support/tenet/integration/api/ida_api.py new file mode 100644 index 0000000..ba207f4 --- /dev/null +++ b/plugins_sogen-support/tenet/integration/api/ida_api.py @@ -0,0 +1,576 @@ +import logging +import functools + +# +# TODO: should probably cleanup / document this file a bit better. +# +# it's worth noting that most of this is based on the same shim layer +# used by lighthouse +# + +import ida_ua +import ida_dbg +import ida_idp +import ida_pro +import ida_auto +import ida_nalt +import ida_name +import ida_xref +import idautils +import ida_bytes +import ida_idaapi +import ida_diskio +import ida_kernwin +import ida_segment +import ida_ida + +from .api import DisassemblerCoreAPI, DisassemblerContextAPI +from ...util.qt import * +from ...util.misc import is_mainthread + +logger = logging.getLogger("Tenet.API.IDA") + +#------------------------------------------------------------------------------ +# Utils +#------------------------------------------------------------------------------ + +def execute_sync(function, sync_type): + """ + Synchronize with the disassembler for safe database access. + + Modified from https://github.com/vrtadmin/FIRST-plugin-ida + """ + + @functools.wraps(function) + def wrapper(*args, **kwargs): + output = [None] + + # + # this inline function definition is technically what will execute + # in the context of the main thread. we use this thunk to capture + # any output the function may want to return to the user. + # + + def thunk(): + output[0] = function(*args, **kwargs) + return 1 + + if is_mainthread(): + thunk() + else: + ida_kernwin.execute_sync(thunk, sync_type) + + # return the output of the synchronized execution + return output[0] + return wrapper + +#------------------------------------------------------------------------------ +# Disassembler Core API (universal) +#------------------------------------------------------------------------------ + +class IDACoreAPI(DisassemblerCoreAPI): + NAME = "IDA" + + def __init__(self): + super(IDACoreAPI, self).__init__() + self._dockable_factory = {} + self._init_version() + + def _init_version(self): + + # retrieve IDA's version # + disassembler_version = ida_kernwin.get_kernel_version() + major, minor = map(int, disassembler_version.split(".")) + + # save the version number components for later use + self._version_major = major + self._version_minor = minor + self._version_patch = 0 + + #-------------------------------------------------------------------------- + # Properties + #-------------------------------------------------------------------------- + + @property + def headless(self): + return ida_kernwin.cvar.batch + + #-------------------------------------------------------------------------- + # Synchronization Decorators + #-------------------------------------------------------------------------- + + @staticmethod + def execute_read(function): + return execute_sync(function, ida_kernwin.MFF_READ) + + @staticmethod + def execute_write(function): + return execute_sync(function, ida_kernwin.MFF_WRITE) + + @staticmethod + def execute_ui(function): + return execute_sync(function, ida_kernwin.MFF_FAST) + + #-------------------------------------------------------------------------- + # API Shims + #-------------------------------------------------------------------------- + + def get_disassembler_user_directory(self): + return ida_diskio.get_user_idadir() + + def refresh_views(self): + ida_kernwin.refresh_idaview_anyway() + + def get_disassembly_background_color(self): + """ + Get the background color of the IDA disassembly view. + """ + + # create a donor IDA 'code viewer' + viewer = ida_kernwin.simplecustviewer_t() + viewer.Create("Colors") + + # get the viewer's qt widget + viewer_twidget = viewer.GetWidget() + viewer_widget = ida_kernwin.PluginForm.TWidgetToPyQtWidget(viewer_twidget) + + # fetch the background color property + #viewer.Show() # TODO: re-enable! + color = viewer_widget.property("line_bg_default") + + # destroy the view as we no longer need it + #viewer.Close() + + # return the color + return color + + def is_msg_inited(self): + return ida_kernwin.is_msg_inited() + + @execute_ui.__func__ + def warning(self, text): + super(IDACoreAPI, self).warning(text) + + @execute_ui.__func__ + def message(self, message): + print(message) + + #-------------------------------------------------------------------------- + # UI API Shims + #-------------------------------------------------------------------------- + + def create_dockable(self, window_title, widget): + + # create a dockable widget, and save a reference to it for later use + twidget = ida_kernwin.create_empty_widget(window_title) + + # cast the IDA 'twidget' as a Qt widget for use + dockable = ida_kernwin.PluginForm.TWidgetToPyQtWidget(twidget) + layout = dockable.layout() + layout.addWidget(widget) + + # return the dockable QtWidget / container + return dockable + +#------------------------------------------------------------------------------ +# Disassembler Context API (database-specific) +#------------------------------------------------------------------------------ + +class IDAContextAPI(DisassemblerContextAPI): + + def __init__(self, dctx): + super(IDAContextAPI, self).__init__(dctx) + + @property + def busy(self): + return not(ida_auto.auto_is_ok()) + + #-------------------------------------------------------------------------- + # API Shims + #-------------------------------------------------------------------------- + + @IDACoreAPI.execute_read + def get_current_address(self): + return ida_kernwin.get_screen_ea() + + def get_processor_type(self): + ## get the target arch, PLFM_386, PLFM_ARM, etc # TODO + #arch = idaapi.ph_get_id() + pass + + def is_64bit(self): + # The get_inf_structure() function is deprecated in newer IDA versions. + # The 'inf' object is now directly accessible via ida_ida. + return ida_ida.inf_is_64bit() + + def is_call_insn(self, address): + insn = ida_ua.insn_t() + if ida_ua.decode_insn(insn, address) and ida_idp.is_call_insn(insn): + return True + return False + + def get_instruction_addresses(self): + """ + Return all instruction addresses from the executable. + """ + instruction_addresses = [] + + for seg_address in idautils.Segments(): + + # fetch code segments + seg = ida_segment.getseg(seg_address) + if seg.sclass != ida_segment.SEG_CODE: + continue + + current_address = seg_address + end_address = seg.end_ea + + # save the address of each instruction in the segment + while current_address < end_address: + current_address = ida_bytes.next_head(current_address, end_address) + if ida_bytes.is_code(ida_bytes.get_flags(current_address)): + instruction_addresses.append(current_address) + + # print(f"Seg {seg.start_ea:08X} --> {seg.end_ea:08X} CODE") + #print(f" -- {len(instruction_addresses):,} instructions found") + + return instruction_addresses + + def is_mapped(self, address): + return ida_bytes.is_mapped(address) + + def get_next_insn(self, address): + + xb = ida_xref.xrefblk_t() + ok = xb.first_from(address, ida_xref.XREF_ALL) + + while ok and xb.iscode: + if xb.type == ida_xref.fl_F: + return xb.to + ok = xb.next_from() + + return -1 + + def get_prev_insn(self, address): + + xb = ida_xref.xrefblk_t() + ok = xb.first_to(address, ida_xref.XREF_ALL) + + while ok and xb.iscode: + if xb.type == ida_xref.fl_F: + return xb.frm + ok = xb.next_to() + + return -1 + + def get_database_directory(self): + return idautils.GetIdbDir() + + def get_function_addresses(self): + return list(idautils.Functions()) + + def get_function_name_at(self, address): + return ida_name.get_short_name(address) + + def get_function_raw_name_at(self, function_address): + return ida_name.get_name(function_address) + + def get_imagebase(self): + return ida_nalt.get_imagebase() + + def get_root_filename(self): + return ida_nalt.get_root_filename() + + def navigate(self, address): + + # TODO fetch active view? or most recent one? i'm lazy for now... + widget = ida_kernwin.find_widget("IDA View-A") + + # + # this call can both navigate to an arbitrary address, and keep + # the cursor position 'static' within the window at an (x,y) + # text position + # + # TODO: I think it's kind of tricky to figure out the 'center' line of + # the disassembly window navigation, so for now we'll just make a + # navigation call always center around line 20... + # + + CENTER_AROUND_LINE_INDEX = 20 + + if widget: + return ida_kernwin.ea_viewer_history_push_and_jump(widget, address, 0, CENTER_AROUND_LINE_INDEX, 0) + + # ehh, whatever.. just let IDA navigate to yolo + else: + return ida_kernwin.jumpto(address) + + def navigate_to_function(self, function_address, address): + return self.navigate(address) + + def set_function_name_at(self, function_address, new_name): + ida_name.set_name(function_address, new_name, ida_name.SN_NOWARN) + + def set_breakpoint(self, address): + ida_dbg.add_bpt(address) + + def delete_breakpoint(self, address): + ida_dbg.del_bpt(address) + +#------------------------------------------------------------------------------ +# HexRays Util +#------------------------------------------------------------------------------ + +def hexrays_available(): + """ + Return True if an IDA decompiler is loaded and available for use. + """ + try: + import ida_hexrays + return ida_hexrays.init_hexrays_plugin() + except ImportError: + return False + +def map_line2citem(decompilation_text): + """ + Map decompilation line numbers to citems. + + This function allows us to build a relationship between citems in the + ctree and specific lines in the hexrays decompilation text. + + Output: + + +- line2citem: + | a map keyed with line numbers, holding sets of citem indexes + | + | eg: { int(line_number): sets(citem_indexes), ... } + ' + + """ + line2citem = {} + + # + # it turns out that citem indexes are actually stored inline with the + # decompilation text output, hidden behind COLOR_ADDR tokens. + # + # here we pass each line of raw decompilation text to our crappy lexer, + # extracting any COLOR_ADDR tokens as citem indexes + # + + for line_number in range(decompilation_text.size()): + line_text = decompilation_text[line_number].line + line2citem[line_number] = lex_citem_indexes(line_text) + + return line2citem + +def map_line2node(cfunc, metadata, line2citem): + """ + Map decompilation line numbers to node (basic blocks) addresses. + + This function allows us to build a relationship between graph nodes + (basic blocks) and specific lines in the hexrays decompilation text. + + Output: + + +- line2node: + | a map keyed with line numbers, holding sets of node addresses + | + | eg: { int(line_number): set(nodes), ... } + ' + + """ + line2node = {} + treeitems = cfunc.treeitems + function_address = cfunc.entry_ea + + # + # prior to this function, a line2citem map was built to tell us which + # citems reside on any given line of text in the decompilation output. + # + # now, we walk through this line2citem map one 'line_number' at a time in + # an effort to resolve the set of graph nodes associated with its citems. + # + + for line_number, citem_indexes in line2citem.items(): + nodes = set() + + # + # we are at the level of a single line (line_number). we now consume + # its set of citems (citem_indexes) and attempt to identify explicit + # graph nodes they claim to be sourced from (by their reported EA) + # + + for index in citem_indexes: + + # get the code address of the given citem + try: + item = treeitems[index] + address = item.ea + + # apparently this is a thing on IDA 6.95 + except IndexError as e: + continue + + # find the graph node (eg, basic block) that generated this citem + node = metadata.get_node(address) + + # address not mapped to a node... weird. continue to the next citem + if not node: + #logger.warning("Failed to map node to basic block") + continue + + # + # we made it this far, so we must have found a node that contains + # this citem. save the computed node_id to the list of known + # nodes we have associated with this line of text + # + + nodes.add(node.address) + + # + # finally, save the completed list of node ids as identified for this + # line of decompilation text to the line2node map that we are building + # + + line2node[line_number] = nodes + + # all done, return the computed map + return line2node + +def lex_citem_indexes(line): + """ + Lex all ctree item indexes from a given line of text. + + The HexRays decompiler output contains invisible text tokens that can + be used to attribute spans of text to the ctree items that produced them. + + This function will simply scrape and return a list of all the these + tokens (COLOR_ADDR) which contain item indexes into the ctree. + + """ + i = 0 + indexes = [] + line_length = len(line) + + # lex COLOR_ADDR tokens from the line of text + while i < line_length: + + # does this character mark the start of a new COLOR_* token? + if line[i] == idaapi.COLOR_ON: + + # yes, so move past the COLOR_ON byte + i += 1 + + # is this sequence for a COLOR_ADDR? + if ord(line[i]) == idaapi.COLOR_ADDR: + + # yes, so move past the COLOR_ADDR byte + i += 1 + + # + # A COLOR_ADDR token is followed by either 8, or 16 characters + # (a hex encoded number) that represents an address/pointer. + # in this context, it is actually the index number of a citem + # + + citem_index = int(line[i:i+idaapi.COLOR_ADDR_SIZE], 16) + i += idaapi.COLOR_ADDR_SIZE + + # save the extracted citem index + indexes.append(citem_index) + + # skip to the next iteration as i has moved + continue + + # nothing we care about happened, keep lexing forward + i += 1 + + # return all the citem indexes extracted from this line of text + return indexes + +class DockableWindow(ida_kernwin.PluginForm): + + def __init__(self, title, widget): + super(DockableWindow, self).__init__() + self.title = title + self.widget = widget + + self.visible = False + self._dock_position = None + self._dock_target = None + + if ida_pro.IDA_SDK_VERSION < 760: + self.__dock_filter = IDADockSizeHack() + + def OnCreate(self, form): + #print("Creating", self.title) + self.parent = self.FormToPyQtWidget(form) + + layout = QtWidgets.QVBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.widget) + self.parent.setLayout(layout) + + if ida_pro.IDA_SDK_VERSION < 760: + self.__dock_size_hack() + + def OnClose(self, foo): + self.visible = False + #print("Closing", self.title) + + def __dock_size_hack(self): + if self.widget.minimumWidth() == 0: + return + self.widget.min_width = self.widget.minimumWidth() + self.widget.max_width = self.widget.maximumWidth() + self.widget.setMinimumWidth(self.widget.min_width // 2) + self.widget.setMaximumWidth(self.widget.min_width // 2) + self.widget.installEventFilter(self.__dock_filter) + + def show(self): + dock_position = self._dock_position + + if ida_pro.IDA_SDK_VERSION < 760: + WOPN_SZHINT = 0x200 + + # create the dockable widget, without actually showing it + self.Show(self.title, options=ida_kernwin.PluginForm.WOPN_CREATE_ONLY) + + # use some kludge to display our widget, and enforce the use of its sizehint + ida_widget = self.GetWidget() + ida_kernwin.display_widget(ida_widget, WOPN_SZHINT) + self.visible = True + + # no hax required for IDA 7.6 and newer + else: + self.Show(self.title) + self.visible = True + dock_position |= ida_kernwin.DP_SZHINT + + # move the window to a given location if specified + if dock_position is not None: + ida_kernwin.set_dock_pos(self.title, self._dock_target, dock_position) + + def hide(self): + self.Close(1) + + def set_dock_position(self, dest_ctrl=None, position=0): + self._dock_target = dest_ctrl + self._dock_position = position + + if not self.visible: + return + + ida_kernwin.set_dock_pos(self.title, dest_ctrl, position) + + def copy_dock_position(self, other): + self._dock_target = other._dock_target + self._dock_position = other._dock_position + +class IDADockSizeHack(QtCore.QObject): + def eventFilter(self, obj, event): + if event.type() == QtCore.QEvent.WindowActivate: + obj.setMinimumWidth(obj.min_width) + obj.setMaximumWidth(obj.max_width) + obj.removeEventFilter(self) + return False \ No newline at end of file diff --git a/plugins_sogen-support/tenet/integration/ida_integration.py b/plugins_sogen-support/tenet/integration/ida_integration.py new file mode 100644 index 0000000..b865995 --- /dev/null +++ b/plugins_sogen-support/tenet/integration/ida_integration.py @@ -0,0 +1,525 @@ +import ctypes +import logging + +# +# TODO: should probably cleanup / document this file a bit better +# + +import ida_dbg +import ida_bytes +import ida_idaapi +import ida_kernwin + +from tenet.core import TenetCore +from tenet.types import BreakpointEvent +from tenet.context import TenetContext +from tenet.util.misc import register_callback, notify_callback, is_plugin_dev +from tenet.util.qt import * + +logger = logging.getLogger("Tenet.IDA.Integration") + +IDA_GLOBAL_CTX = "blah this value doesn't matter" + +#------------------------------------------------------------------------------ +# IDA UI Integration +#------------------------------------------------------------------------------ + +class TenetIDA(TenetCore): + """ + The plugin integration layer IDA Pro. + """ + + def __init__(self): + + # + # icons + # + + self._icon_id_file = ida_idaapi.BADADDR + self._icon_id_next_execution = ida_idaapi.BADADDR + self._icon_id_prev_execution = ida_idaapi.BADADDR + + # + # event hooks + # + + self._hooked = False + + self._ui_hooks = UIHooks() + self._ui_hooks.get_lines_rendering_info = self._render_lines + self._ui_hooks.finish_populating_widget_popup = self._popup_hook + + self._dbg_hooks = DbgHooks() + self._dbg_hooks.dbg_bpt_changed = self._breakpoint_changed_hook + + # + # we should always hook the UI early in dev mode as we will use UI + # events to auto-launch a trace + # + + if is_plugin_dev(): + self._ui_hooks.hook() + + # + # callbacks + # + + self._ui_breakpoint_changed_callbacks = [] + + # + # run disassembler-agnostic core initalization + # + + super(TenetIDA, self).__init__() + + def hook(self): + if self._hooked: + return + self._hooked = True + self._ui_hooks.hook() + self._dbg_hooks.hook() + + def unhook(self): + if not self._hooked: + return + self._hooked = False + self._ui_hooks.unhook() + self._dbg_hooks.unhook() + + def get_context(self, dctx, startup=True): + """ + Get the plugin context for a given database. + + NOTE: since IDA can only have one binary / IDB open at a time, the + dctx (database context) should always be IDA_GLOBAL_CTX. + """ + assert dctx is IDA_GLOBAL_CTX + self.palette.warmup() + + # + # there should only ever be 'one' disassembler / IDB context at any + # time for IDA. but if one does not exist yet, that means this is the + # first time the user has interacted with the plugin for this session + # + + if dctx not in self.contexts: + + # create a new 'plugin context' representing this IDB + pctx = TenetContext(self, dctx) + if startup: + pctx.start() + + # save the created ctx for future calls + self.contexts[dctx] = pctx + + # return the plugin context object for this IDB + return self.contexts[dctx] + + #-------------------------------------------------------------------------- + # IDA Actions + #-------------------------------------------------------------------------- + + ACTION_LOAD_TRACE = "tenet:load_trace" + ACTION_FIRST_EXECUTION = "tenet:first_execution" + ACTION_FINAL_EXECUTION = "tenet:final_execution" + ACTION_NEXT_EXECUTION = "tenet:next_execution" + ACTION_PREV_EXECUTION = "tenet:prev_execution" + + def _install_load_trace(self): + + # TODO: create a custom IDA icon + #icon_path = plugin_resource(os.path.join("icons", "load.png")) + #icon_data = open(icon_path, "rb").read() + #self._icon_id_file = ida_kernwin.load_custom_icon(data=icon_data) + + # describe a custom IDA UI action + action_desc = ida_kernwin.action_desc_t( + self.ACTION_LOAD_TRACE, # The action name + "~T~enet trace file...", # The action text + IDACtxEntry(self._interactive_load_trace), # The action handler + None, # Optional: action shortcut + "Load a Tenet trace file", # Optional: tooltip + -1 # Optional: the action icon + ) + + # register the action with IDA + result = ida_kernwin.register_action(action_desc) + assert result, f"Failed to register '{action_desc.name}' action with IDA" + + # attach the action to the File-> dropdown menu + result = ida_kernwin.attach_action_to_menu( + "File/Load file/", # Relative path of where to add the action + self.ACTION_LOAD_TRACE, # The action ID (see above) + ida_kernwin.SETMENU_APP # We want to append the action after ^ + ) + assert result, f"Failed action attach {action_desc.name}" + + logger.info(f"Installed the '{action_desc.name}' menu entry") + + def _install_next_execution(self): + + icon_data = self.palette.gen_arrow_icon(self.palette.arrow_next, 0) + self._icon_id_next_execution = ida_kernwin.load_custom_icon(data=icon_data) + + # describe a custom IDA UI action + action_desc = ida_kernwin.action_desc_t( + self.ACTION_NEXT_EXECUTION, # The action name + "Go to next execution", # The action text + IDACtxEntry(self._interactive_next_execution), # The action handler + None, # Optional: action shortcut + "Go to the next execution of the current address", # Optional: tooltip + self._icon_id_next_execution # Optional: the action icon + ) + + # register the action with IDA + result = ida_kernwin.register_action(action_desc) + assert result, f"Failed to register '{action_desc.name}' action with IDA" + logger.info(f"Installed the '{action_desc.name}' menu entry") + + def _install_prev_execution(self): + + icon_data = self.palette.gen_arrow_icon(self.palette.arrow_prev, 180.0) + self._icon_id_prev_execution = ida_kernwin.load_custom_icon(data=icon_data) + + # describe a custom IDA UI action + action_desc = ida_kernwin.action_desc_t( + self.ACTION_PREV_EXECUTION, # The action name + "Go to previous execution", # The action text + IDACtxEntry(self._interactive_prev_execution), # The action handler + None, # Optional: action shortcut + "Go to the previous execution of the current address", # Optional: tooltip + self._icon_id_prev_execution # Optional: the action icon + ) + + # register the action with IDA + result = ida_kernwin.register_action(action_desc) + assert result, f"Failed to register '{action_desc.name}' action with IDA" + logger.info(f"Installed the '{action_desc.name}' menu entry") + + def _install_first_execution(self): + + # describe a custom IDA UI action + action_desc = ida_kernwin.action_desc_t( + self.ACTION_FIRST_EXECUTION, # The action name + "Go to first execution", # The action text + IDACtxEntry(self._interactive_first_execution), # The action handler + None, # Optional: action shortcut + "Go to the first execution of the current address", # Optional: tooltip + -1 # Optional: the action icon + ) + + # register the action with IDA + result = ida_kernwin.register_action(action_desc) + assert result, f"Failed to register '{action_desc.name}' action with IDA" + logger.info(f"Installed the '{action_desc.name}' menu entry") + + def _install_final_execution(self): + + # describe a custom IDA UI action + action_desc = ida_kernwin.action_desc_t( + self.ACTION_FINAL_EXECUTION, # The action name + "Go to final execution", # The action text + IDACtxEntry(self._interactive_final_execution), # The action handler + None, # Optional: action shortcut + "Go to the final execution of the current address", # Optional: tooltip + -1 # Optional: the action icon + ) + + # register the action with IDA + result = ida_kernwin.register_action(action_desc) + assert result, f"Failed to register '{action_desc.name}' action with IDA" + logger.info(f"Installed the '{action_desc.name}' menu entry") + + def _uninstall_load_trace(self): + + logger.info("Removing the 'Tenet trace file...' menu entry...") + + # remove the entry from the File-> menu + result = ida_kernwin.detach_action_from_menu( + "File/Load file/", + self.ACTION_LOAD_TRACE + ) + if not result: + logger.warning("Failed to detach action from menu...") + return False + + # unregister the action + result = ida_kernwin.unregister_action(self.ACTION_LOAD_TRACE) + if not result: + logger.warning("Failed to unregister action...") + return False + + # delete the entry's icon + #ida_kernwin.free_custom_icon(self._icon_id_file) # TODO + self._icon_id_file = ida_idaapi.BADADDR + + logger.info("Successfully removed the menu entry!") + return True + + def _uninstall_next_execution(self): + result = self._uninstall_action(self.ACTION_NEXT_EXECUTION, self._icon_id_next_execution) + self._icon_id_next_execution = ida_idaapi.BADADDR + return result + + def _uninstall_prev_execution(self): + result = self._uninstall_action(self.ACTION_PREV_EXECUTION, self._icon_id_prev_execution) + self._icon_id_prev_execution = ida_idaapi.BADADDR + return result + + def _uninstall_first_execution(self): + return self._uninstall_action(self.ACTION_FIRST_EXECUTION) + + def _uninstall_final_execution(self): + return self._uninstall_action(self.ACTION_FINAL_EXECUTION) + + def _uninstall_action(self, action, icon_id=ida_idaapi.BADADDR): + + result = ida_kernwin.unregister_action(action) + if not result: + logger.warning(f"Failed to unregister {action}...") + return False + + if icon_id != ida_idaapi.BADADDR: + ida_kernwin.free_custom_icon(icon_id) + + logger.info(f"Uninstalled the {action} menu entry") + return True + + #-------------------------------------------------------------------------- + # UI Event Handlers + #-------------------------------------------------------------------------- + + def _breakpoint_changed_hook(self, code, bpt): + """ + (Event) Breakpoint changed. + """ + + if code == ida_dbg.BPTEV_ADDED: + self._notify_ui_breakpoint_changed(bpt.ea, BreakpointEvent.ADDED) + + elif code == ida_dbg.BPTEV_CHANGED: + if bpt.enabled(): + self._notify_ui_breakpoint_changed(bpt.ea, BreakpointEvent.ENABLED) + else: + self._notify_ui_breakpoint_changed(bpt.ea, BreakpointEvent.DISABLED) + + elif code == ida_dbg.BPTEV_REMOVED: + self._notify_ui_breakpoint_changed(bpt.ea, BreakpointEvent.REMOVED) + + return 0 + + def _popup_hook(self, widget, popup): + """ + (Event) IDA is about to show a popup for the given TWidget. + """ + + # TODO: return if plugin/trace is not active + pass + + # fetch the (IDA) window type (eg, disas, graph, hex ...) + view_type = ida_kernwin.get_widget_type(widget) + + # only attach these context items to popups in disas views + if view_type == ida_kernwin.BWN_DISASMS: + + # prep for some shady hacks + p_qmenu = ctypes.cast(int(popup), ctypes.POINTER(ctypes.c_void_p))[0] + qmenu = sip.wrapinstance(int(p_qmenu), QtWidgets.QMenu) + + # + # inject and organize the Tenet plugin actions + # + + ida_kernwin.attach_action_to_popup( + widget, + popup, + self.ACTION_NEXT_EXECUTION, # The action ID (see above) + "Rename", # Relative path of where to add the action + ida_kernwin.SETMENU_APP # We want to append the action after ^ + ) + + # + # this is part of our bodge to inject a plugin action submenu + # at a specific location in the QMenu, cuz I don't think it's + # actually possible with the native IDA API's (for groups...) + # + + for action in qmenu.actions(): + if action.text() == "Go to next execution": + + # inject a group for the exta 'go to' actions + goto_submenu = QtWidgets.QMenu("Go to...") + qmenu.insertMenu(action, goto_submenu) + + # hold a Qt ref of the submenu so it doesn't GC + self.__goto_submenu = goto_submenu + break + + ida_kernwin.attach_action_to_popup( + widget, + popup, + self.ACTION_FIRST_EXECUTION, # The action ID (see above) + "Go to.../", # Relative path of where to add the action + ida_kernwin.SETMENU_APP # We want to append the action after ^ + ) + + ida_kernwin.attach_action_to_popup( + widget, + popup, + self.ACTION_FINAL_EXECUTION, # The action ID (see above) + "Go to.../", # Relative path of where to add the action + ida_kernwin.SETMENU_APP # We want to append the action after ^ + ) + + ida_kernwin.attach_action_to_popup( + widget, + popup, + self.ACTION_PREV_EXECUTION, # The action ID (see above) + "Rename", # Relative path of where to add the action + ida_kernwin.SETMENU_APP # We want to append the action after ^ + ) + + # + # inject a seperator to help insulate our plugin action group + # + + for action in qmenu.actions(): + if action.text() == "Go to previous execution": + qmenu.insertSeparator(action) + break + + def _render_lines(self, lines_out, widget, lines_in): + """ + (Event) IDA is about to render code viewer lines. + """ + widget_type = ida_kernwin.get_widget_type(widget) + + if widget_type == ida_kernwin.BWN_DISASM: + self._highlight_disassesmbly(lines_out, widget, lines_in) + + return + + def _highlight_disassesmbly(self, lines_out, widget, lines_in): + """ + TODO/XXX this is pretty gross + """ + ctx = self.get_context(IDA_GLOBAL_CTX) + if not ctx.reader: + return + + trail_length = 6 + + forward_color = self.palette.trail_forward + current_color = self.palette.trail_current + backward_color = self.palette.trail_backward + + r, g, b, _ = current_color.getRgb() + current_color = 0xFF << 24 | b << 16 | g << 8 | r + + step_over = False + modifiers = QtGui.QGuiApplication.keyboardModifiers() + step_over = bool(modifiers & QtCore.Qt.ShiftModifier) + + forward_ips = ctx.reader.get_next_ips(trail_length, step_over) + backward_ips = ctx.reader.get_prev_ips(trail_length, step_over) + + backward_trail, forward_trail = {}, {} + + trails = [ + (backward_ips, backward_trail, backward_color), + (forward_ips, forward_trail, forward_color) + ] + + for addresses, trail, color in trails: + for i, address in enumerate(addresses): + percent = 1.0 - ((trail_length - i) / trail_length) + + # convert to bgr + r, g, b, _ = color.getRgb() + ida_color = b << 16 | g << 8 | r + ida_color |= (0xFF - int(0xFF * percent)) << 24 + + # save the trail color + rebased_address = ctx.reader.analysis.rebase_pointer(address) + trail[rebased_address] = ida_color + + current_address = ctx.reader.rebased_ip + if not ida_bytes.is_mapped(current_address): + last_good_idx = ctx.reader.analysis.get_prev_mapped_idx(ctx.reader.idx) + if last_good_idx != -1: + + # fetch the last instruction pointer to fall within the trace + last_good_trace_address = ctx.reader.get_ip(last_good_idx) + + # convert the trace-based instruction pointer to one that maps to the disassembler + current_address = ctx.reader.analysis.rebase_pointer(last_good_trace_address) + + for section in lines_in.sections_lines: + for line in section: + address = line.at.toea() + + if address in backward_trail: + color = backward_trail[address] + elif address in forward_trail: + color = forward_trail[address] + elif address == current_address: + color = current_color + else: + continue + + entry = ida_kernwin.line_rendering_output_entry_t(line, ida_kernwin.LROEF_FULL_LINE, color) + lines_out.entries.push_back(entry) + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + def ui_breakpoint_changed(self, callback): + register_callback(self._ui_breakpoint_changed_callbacks, callback) + + def _notify_ui_breakpoint_changed(self, address, code): + notify_callback(self._ui_breakpoint_changed_callbacks, address, code) + +#------------------------------------------------------------------------------ +# IDA UI Helpers +#------------------------------------------------------------------------------ + +class IDACtxEntry(ida_kernwin.action_handler_t): + """ + A minimal context menu entry class to utilize IDA's action handlers. + """ + + def __init__(self, action_function): + super(IDACtxEntry, self).__init__() + self.action_function = action_function + + def activate(self, ctx): + """ + Execute the embedded action_function when this context menu is invoked. + + NOTE: We pass 'None' to the action function to act as the ' + """ + self.action_function(IDA_GLOBAL_CTX) + return 1 + + def update(self, ctx): + """ + Ensure the context menu is always available in IDA. + """ + return ida_kernwin.AST_ENABLE_ALWAYS + +#------------------------------------------------------------------------------ +# IDA UI Event Hooks +#------------------------------------------------------------------------------ + +class DbgHooks(ida_dbg.DBG_Hooks): + def dbg_bpt_changed(self, code, bpt): + pass + +class UIHooks(ida_kernwin.UI_Hooks): + def get_lines_rendering_info(self, lines_out, widget, lines_in): + pass + def ready_to_run(self): + pass + def finish_populating_widget_popup(self, widget, popup): + pass \ No newline at end of file diff --git a/plugins_sogen-support/tenet/integration/ida_loader.py b/plugins_sogen-support/tenet/integration/ida_loader.py new file mode 100644 index 0000000..3d44614 --- /dev/null +++ b/plugins_sogen-support/tenet/integration/ida_loader.py @@ -0,0 +1,105 @@ +import time +import logging + +import ida_idaapi +import ida_kernwin + +from tenet.util.log import pmsg +from tenet.integration.ida_integration import TenetIDA + +logger = logging.getLogger("Tenet.IDA.Loader") + +#------------------------------------------------------------------------------ +# IDA Plugin Loader +#------------------------------------------------------------------------------ +# +# This file contains a stub 'plugin' class for the plugin as required by +# IDA Pro. Practically speaking, there should be little to *no* logic placed +# in this file because it is disassembler-specific. +# +# When IDA Pro is starting up, it will import all python files placed in its +# root plugin folder. It will then attempt to call PLUGIN_ENTRY() on each of +# the imported 'plugins'. We import PLUGIN_ENTRY into tenet_plugin.py +# so that IDA can see it. +# +# PLUGIN_ENTRY() is expected to return a plugin object (TenetIDAPlugin) +# derived from ida_idaapi.plugin_t. IDA will register the plugin, and +# interface with the plugin object to load / unload the plugin at certain +# times, per its configuration (flags, hotkeys). +# +# There should be virtually no reason for you to modify this file. +# + +def PLUGIN_ENTRY(): + """ + Required plugin entry point for IDAPython Plugins. + """ + return TenetIDAPlugin() + +class TenetIDAPlugin(ida_idaapi.plugin_t): + """ + The IDA plugin stub for Tenet. + """ + + # + # Plugin flags: + # - PLUGIN_MOD: The plugin may modify the database + # - PLUGIN_PROC: Load/unload the plugin when an IDB opens / closes + # - PLUGIN_HIDE: Hide the plugin from the IDA plugin menu + # + + flags = ida_idaapi.PLUGIN_PROC | ida_idaapi.PLUGIN_MOD | ida_idaapi.PLUGIN_HIDE + comment = "Trace Explorer" + help = "" + wanted_name = "Tenet" + wanted_hotkey = "" + + #-------------------------------------------------------------------------- + # IDA Plugin Overloads + #-------------------------------------------------------------------------- + + def init(self): + """ + This is called by IDA when it is loading the plugin. + """ + + try: + self.core = TenetIDA() + self.core.load() + except Exception as e: + pmsg("Failed to initialize Tenet") + logger.exception("Exception details:") + + # + # we return PLUGIN_KEEP here regardless of success/failure. this is to + # ensure that IDA will not try to reload the plugin again. + # + + return ida_idaapi.PLUGIN_KEEP + + def run(self, arg): + """ + This is called by IDA when this file is loaded as a script. + """ + ida_kernwin.warning("Tenet cannot be run as a script in IDA.") + + def term(self): + """ + This is called by IDA when it is unloading the plugin. + """ + logger.debug("IDA term started...") + + start = time.time() + logger.debug("-"*50) + + try: + self.core.unload() + self.core = None + except Exception as e: + logger.exception("Failed to cleanly unload Tenet from IDA.") + + end = time.time() + logger.debug("-"*50) + + logger.debug("IDA term done... (%.3f seconds...)" % (end-start)) + diff --git a/plugins_sogen-support/tenet/memory.py b/plugins_sogen-support/tenet/memory.py new file mode 100644 index 0000000..08b4293 --- /dev/null +++ b/plugins_sogen-support/tenet/memory.py @@ -0,0 +1,23 @@ +from tenet.hex import HexController + +#------------------------------------------------------------------------------ +# memory.py -- Memory Dump Controller +#------------------------------------------------------------------------------ +# +# The purpose of this file is to house the 'headless' components of the +# memory dump window and its underlying functionality. This is split into +# a model and controller component, of a typical 'MVC' design pattern. +# +# As our memory dumps are largely abstracted off a generic 'hex dump', +# there is very little code that actually has to be applied here (for now) +# + +class MemoryController(HexController): + """ + The Memory Dump Controller (Logic) + """ + + def __init__(self, pctx): + super(MemoryController, self).__init__(pctx) + self._title = "Memory View" + #self.model.hex_format = HexType.MAGIC diff --git a/plugins_sogen-support/tenet/registers.py b/plugins_sogen-support/tenet/registers.py new file mode 100644 index 0000000..7491b06 --- /dev/null +++ b/plugins_sogen-support/tenet/registers.py @@ -0,0 +1,376 @@ +from tenet.ui import * +from tenet.util.misc import register_callback, notify_callback +from tenet.integration.api import DockableWindow, disassembler + +#------------------------------------------------------------------------------ +# registers.py -- Register Controller +#------------------------------------------------------------------------------ +# +# The purpose of this file is to house the 'headless' components of the +# registers window and its underlying functionality. This is split into a +# model and controller component, of a typical 'MVC' design pattern. +# +# NOTE: for the time being, this file also contains the logic for the +# 'IDX Shell' as it is kind of attached to the register view and not big +# enough to demand its own seperate structuring ... yet +# + +class RegisterController(object): + """ + The Registers Controller (Logic) + """ + + def __init__(self, pctx): + self.pctx = pctx + self.model = RegistersModel(pctx) + self.reader = None + + # UI components + self.view = None + self.dockable = None + + # signals + self._ignore_signals = False + pctx.breakpoints.model.breakpoints_changed(self._breakpoints_changed) + + def show(self, target=None, position=0): + """ + Make the window attached to this controller visible. + """ + + # if there is no Qt (eg, our UI framework...) then there is no UI + if not QT_AVAILABLE: + return + + # the UI has already been created, and is also visible. nothing to do + if (self.dockable and self.dockable.visible): + return + + # + # if the UI has not yet been created, or has been previously closed + # then we are free to create new UI elements to take the place of + # anything that once was + # + + self.view = RegisterView(self, self.model) + new_dockable = DockableWindow("CPU Registers", self.view) + + # + # if there is a reference to a left over dockable window (e.g, from a + # previous close of this window type) steal its dock positon so we can + # hopefully take the same place as the old one + # + + if self.dockable: + new_dockable.copy_dock_position(self.dockable) + elif (target or position): + new_dockable.set_dock_position(target, position) + + # make the dockable/widget visible + self.dockable = new_dockable + self.dockable.show() + + def hide(self): + """ + Hide the window attached to this controller. + """ + + # if there is no view/dockable, then there's nothing to try and hide + if not(self.view and self.dockable): + return + + # hide the dockable, and drop references to the widgets + self.dockable.hide() + self.view = None + self.dockable = None + + def attach_reader(self, reader): + """ + Attach a trace reader to this controller. + """ + self.reader = reader + + # attach trace reader signals to this controller / window + reader.idx_changed(self._idx_changed) + + # + # directly call our event handler quick with the current idx since + # it's the first time we're seeing this. this ensures that our widget + # will accurately reflect the current state of the reader + # + + self._idx_changed(reader.idx) + + def detach_reader(self): + """ + Detach the active trace reader from this controller. + """ + self.reader = None + self.model.reset() + + def set_ip_breakpoint(self): + """ + Set an execution breakpoint on the current instruction pointer. + """ + current_ip = self.model.registers[self.model.arch.IP] + + self._ignore_signals = True + self.pctx.breakpoints.clear_execution_breakpoints() + self.pctx.breakpoints.add_execution_breakpoint(current_ip) + self._ignore_signals = False + + if self.view: + self.view.refresh() + + # TODO: maybe we can remove all these 'focus' funcs now? + def focus_register_value(self, reg_name): + """ + Focus a register value in the register view. + """ + self.model.focused_reg_value = reg_name + + def focus_register_name(self, reg_name): + """ + Focus a register name in the register view. + """ + self._clear_register_value_focus() + self.model.focused_reg_name = reg_name + + def clear_register_focus(self): + """ + Clear all focus on register fields. + """ + self._clear_register_value_focus() + self.model.focused_reg_name = None + + def follow_in_dump(self, reg_name): + """ + Follow a given register value in the memory dump. + """ + address = self.model.registers[reg_name] + self.pctx.memory.navigate(address) + + def _clear_register_value_focus(self): + """ + Clear focus from the active register field. + """ + self.model.focused_reg_value = None + + def set_registers(self, registers, delta=None): + """ + Set the registers for the view. + """ + self.model.set_registers(registers, delta) + + def evaluate_expression(self, expression): + """ + Evaluate the expression in the IDX Shell and navigate to it. + """ + + # a target idx was given as an integer + if isinstance(expression, int): + target_idx = expression + self.reader.seek(target_idx) + + # string handling + elif isinstance(expression, str): + + # blank string was passed from the shell, nothing to do... + if not expression: + return + + # a 'command' / alias idx was entered into the shell ('!...' prefix) + if expression[0] == '!': + self._handle_command(expression[1:]) + return + + # + # not a command, how about a comma seperated timestamp? + # -- e.g '5,218,121' + # + + idx_str = expression.replace(',', '') + try: + target_idx = int(idx_str) + except: + return + + self.reader.seek(target_idx) + + else: + raise ValueError(f"Unknown input expression type '{expression}'?!?") + + def _handle_command(self, expression): + """ + Handle the evaluation of commands on the timestamp shell. + """ + if self._handle_seek_percent(expression): + return True + if self._handle_seek_last(expression): + return True + return False + + def _handle_seek_percent(self, expression): + """ + Handle a 'percentage-based' trace seek. + + eg: !0, or !100 to skip to the start/end of trace + """ + try: + target_percent = float(expression) # float, so you could even do 42.1% + except: + return False + + # seek to the desired percentage in the trace + self.reader.seek_percent(target_percent) + return True + + def _handle_seek_last(self, expression): + """ + Handle a seek to the last mapped address. + """ + if expression != 'last': + return False + + last_idx = self.reader.trace.length - 1 + last_ip = self.reader.get_ip(last_idx) + rebased_ip = self.reader.analysis.rebase_pointer(last_ip) + + dctx = disassembler[self.pctx] + if not dctx.is_mapped(rebased_ip): + last_good_idx = self.reader.analysis.get_prev_mapped_idx(last_idx) + if last_good_idx == -1: + return False # navigation is just not gonna happen... + last_idx = last_good_idx + + # seek to the last known / good idx that is mapped within the disassembler + self.reader.seek(last_idx) + return True + + def _idx_changed(self, idx): + """ + The trace position has been changed. + """ + self.model.idx = idx + self.set_registers(self.reader.registers, self.reader.trace.get_reg_delta(idx).keys()) + + def _breakpoints_changed(self): + """ + Handle breakpoints changed event. + """ + if not self.view: + return + self.view.refresh() + + def _idx_changed(self, idx): + """ + The trace position has been changed. + """ + self.model.idx = idx + self.set_registers(self.reader.registers, self.reader.trace.get_reg_delta(idx).keys()) + + def _breakpoints_changed(self): + """ + Handle breakpoints changed event. + """ + if not self.view: + return + self.view.refresh() + +class RegistersModel(object): + """ + The Registers Model (Data) + """ + + def __init__(self, pctx): + self._pctx = pctx + self.reset() + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + self._registers_changed_callbacks = [] + + #---------------------------------------------------------------------- + # Properties + #---------------------------------------------------------------------- + + @property + def arch(self): + """ + Return the architecture definition. + """ + return self._pctx.arch + + @property + def execution_breakpoints(self): + """ + Return the set of active execution breakpoints. + """ + return self._pctx.breakpoints.model.bp_exec + + #---------------------------------------------------------------------- + # Public + #---------------------------------------------------------------------- + + def reset(self): + + # the current timestamp in the trace + self.idx = -1 + + # the { reg_name: reg_value } dict of current register values + self.registers = {} + + # + # the names of the registers that have changed since the previous + # chronological timestamp in the trace. + # + # for example if you singlestep forward, any registers that changed as + # a result of 'normal execution' may be highlighted (e.g. red) + # + + self.delta_trace = [] + + # + # the names of registers that have changed since the last navigation + # event (eg, skipping between breakpoints, memory accesses). + # + # this is used to highlight registers that may not have changed as a + # result of the previous chronological trace event, but by means of + # user navigation within tenet. + # + + self.delta_navigation = [] + + self.focused_reg_name = None + self.focused_reg_value = None + + def set_registers(self, registers, delta=None): + + # compute which registers changed as a result of navigation + unchanged = dict(set(self.registers.items()) & set(registers.items())) + self.delta_navigation = set([k for k in registers if k not in unchanged]) + + # save the register delta that changed since the previous trace timestamp + self.delta_trace = delta if delta else [] + self.registers = registers + + # notify the UI / listeners of the model that an update occurred + self._notify_registers_changed() + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + def registers_changed(self, callback): + """ + Subscribe a callback for a registers changed event. + """ + register_callback(self._registers_changed_callbacks, callback) + + def _notify_registers_changed(self): + """ + Notify listeners of a registers changed event. + """ + notify_callback(self._registers_changed_callbacks) diff --git a/plugins_sogen-support/tenet/stack.py b/plugins_sogen-support/tenet/stack.py new file mode 100644 index 0000000..5d104bb --- /dev/null +++ b/plugins_sogen-support/tenet/stack.py @@ -0,0 +1,105 @@ +import struct + +from tenet.ui import * +from tenet.hex import HexController +from tenet.types import HexType, AuxType + +#------------------------------------------------------------------------------ +# stack.py -- Stack Dump Controller +#------------------------------------------------------------------------------ +# +# The purpose of this file is to house the 'headless' components of the +# stack dump window and its underlying functionality. This is split into +# a model and controller component, of a typical 'MVC' design pattern. +# +# The stack dump window abstracts from a simple hex dump. We use the code +# below to configure our underlying hex dump to appear more like a typical +# stack view might instead. +# + +class StackController(HexController): + """ + The Stack Dump Controller (Logic) + """ + + def __init__(self, pctx): + super(StackController, self).__init__(pctx) + self._title = "Stack View" + + def attach_reader(self, reader): + """ + Attach a trace reader, and configure the view to model a stack. + """ + self.model.num_bytes_per_line = reader.arch.POINTER_SIZE + self.model.hex_format = HexType.DWORD if reader.arch.POINTER_SIZE == 4 else HexType.QWORD + self.model.aux_format = AuxType.STACK + super(StackController, self).attach_reader(reader) + + def follow_in_dump(self, stack_address): + """ + Follow the pointer at a given stack address in the memory dump. + """ + POINTER_SIZE = self.pctx.reader.arch.POINTER_SIZE + + # align the given stack address (which we will read..) + stack_address &= ~(POINTER_SIZE - 1) + + # + # compute the relative index of the stack entry, which we will + # use to carve data from the currently visible stack model + # + + relative_index = stack_address - self.model.address + + # attempt to carve the data and validity mask from the stack model + try: + data = self.model.data[relative_index:relative_index+POINTER_SIZE] + mask = self.model.mask[relative_index:relative_index+POINTER_SIZE] + except: + return False + + # ensure the carved data is fully resolved (e.g. there are no unknown bytes) + if not (len(mask) == POINTER_SIZE and list(set(mask)) == [0xFF]): + return False + + # unpack the carved data as a pointer + parsed_address = struct.unpack("I" if POINTER_SIZE == 4 else "Q", data)[0] + + # navigate the memory dump window to the 'pointer' we carved off the stack + self.pctx.memory.navigate(parsed_address) + + def _idx_changed(self, idx): + """ + Override the default hex view idx changed event handler. + """ + + # fade out the upper part of the stack that is currently 'unallocated' + self.set_fade_threshold(self.reader.sp) + + if self.view: + + # + # if the user has a byte / range selected or the view is purposely + # omitting navigation events, we will *not* move the stack view on + # idx changes. + # + # this is to preserve the location of their selection on-screen + # (eg, when hovering a selected byte, and jumping between its + # memory accesses) + # + + if self.view._ignore_navigation or self.view.selection_size: + self.refresh_memory() + self.view.refresh() + return + + # + # if there is no special user interaction going on with the stack + # view, we will simply ensure that the stack stays 'pinned' to the + # top of the stack, per the current trace reader state. + # + # we conciously chose to show '3' lines of the unallocated frames + # to provide a bit more awarness to pops/rets as they happen + # + + self.navigate(self.reader.sp - self.model.num_bytes_per_line * 3) \ No newline at end of file diff --git a/plugins_sogen-support/tenet/trace/__init__.py b/plugins_sogen-support/tenet/trace/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins_sogen-support/tenet/trace/__pycache__/__init__.cpython-311.pyc b/plugins_sogen-support/tenet/trace/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..666e00e Binary files /dev/null and b/plugins_sogen-support/tenet/trace/__pycache__/__init__.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/__pycache__/analysis.cpython-311.pyc b/plugins_sogen-support/tenet/trace/__pycache__/analysis.cpython-311.pyc new file mode 100644 index 0000000..5a139d4 Binary files /dev/null and b/plugins_sogen-support/tenet/trace/__pycache__/analysis.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/__pycache__/file.cpython-311.pyc b/plugins_sogen-support/tenet/trace/__pycache__/file.cpython-311.pyc new file mode 100644 index 0000000..f53bba1 Binary files /dev/null and b/plugins_sogen-support/tenet/trace/__pycache__/file.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/__pycache__/reader.cpython-311.pyc b/plugins_sogen-support/tenet/trace/__pycache__/reader.cpython-311.pyc new file mode 100644 index 0000000..02fe3aa Binary files /dev/null and b/plugins_sogen-support/tenet/trace/__pycache__/reader.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/__pycache__/types.cpython-311.pyc b/plugins_sogen-support/tenet/trace/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000..91b811c Binary files /dev/null and b/plugins_sogen-support/tenet/trace/__pycache__/types.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/analysis.py b/plugins_sogen-support/tenet/trace/analysis.py new file mode 100644 index 0000000..0820c5f --- /dev/null +++ b/plugins_sogen-support/tenet/trace/analysis.py @@ -0,0 +1,253 @@ +import bisect +import collections + +from tenet.util.log import pmsg + +#----------------------------------------------------------------------------- +# analysis.py -- Trace Analysis +#----------------------------------------------------------------------------- +# +# This file should contain logic to further process, augment, optimize or +# annotate Tenet traces when a binary analysis framework such as IDA / +# Binary Ninja is available to a trace reader. +# +# As of now (v0.2) the only added analysis we do is to try and map +# ASLR'd trace addresses to executable opened in the database. +# +# In the future, I imagine this file will be used to indexing events +# such as function calls, returns, entry and exit to unmapped regions, +# service pointer annotations, and much more. +# + +class TraceAnalysis(object): + """ + A high level, debugger-like interface for querying Tenet traces. + """ + + def __init__(self, trace, dctx): + self._dctx = dctx + self._trace = trace + self._remapped_regions = [] + self._unmapped_entry_points = [] + self.slide = None + self._analyze() + + #------------------------------------------------------------------------- + # Public + #------------------------------------------------------------------------- + + def rebase_pointer(self, address): + """ + Return a rebased version of the given address, if one exists. + """ + for m1, m2 in self._remapped_regions: + #print(f"m1 start: {m1[0]:08X} address: {address:08X} m1 end: {m1[1]:08X}") + #print(f"m2 start: {m2[0]:08X} address: {address:08X} m2 end: {m2[1]:08X}") + if m1[0] <= address <= m1[1]: + return address + (m2[0] - m1[0]) + if m2[0] <= address <= m2[1]: + return address - (m2[0] - m1[0]) + return address + + def get_prev_mapped_idx(self, idx): + """ + Return the previous idx to fall within a mapped code region. + """ + index = bisect.bisect_right(self._unmapped_entry_points, idx) - 1 + try: + return self._unmapped_entry_points[index] + except IndexError: + return -1 + + #------------------------------------------------------------------------- + # Analysis + #------------------------------------------------------------------------- + + def _analyze(self): + """ + Analyze the trace against the binary loaded by the disassembler. + """ + self._analyze_aslr() + self._analyze_unmapped() + + def _analyze_aslr(self): + """ + Analyze trace execution to resolve ASLR mappings against the disassembler. + """ + dctx, trace = self._dctx, self._trace + + # get *all* of the instruction addresses from disassembler + instruction_addresses = dctx.get_instruction_addresses() + + # + # bucket the instruction addresses from the disassembler + # based on non-aslr'd bits (lower 12 bits, 0xFFF) + # + + binary_buckets = collections.defaultdict(list) + for address in instruction_addresses: + bits = address & 0xFFF + binary_buckets[bits].append(address) + + # get the set of unique, executed addresses from the trace + trace_addresses = trace.ip_addrs + + # + # scan the executed addresses from the trace, and discard + # any that cannot be bucketed by the non ASLR-d bits that + # match the open executable + # + + trace_buckets = collections.defaultdict(list) + for executed_address in trace_addresses: + bits = executed_address & 0xFFF + if bits not in binary_buckets: + continue + trace_buckets[bits].append(executed_address) + + # + # this is where things get a little bit interesting. we compute the + # distance between addresses in the trace and disassembler buckets + # + # the distance that appears most frequently is likely to be the ASLR + # slide to align the disassembler imagebase and trace addresses + # + + slide_buckets = collections.defaultdict(list) + for bits, bin_addresses in binary_buckets.items(): + for executed_address in trace_buckets[bits]: + for disas_address in bin_addresses: + distance = disas_address - executed_address + slide_buckets[distance].append(executed_address) + + # basically the executable 'range' of the open binary + disas_low_address = instruction_addresses[0] + disas_high_address = instruction_addresses[-1] + + # convert to set for O(1) lookup in following loop + instruction_addresses = set(instruction_addresses) + + # + # loop through all the slide buckets, from the most frequent distance + # (ASLR slide) to least frequent. the goal now is to sanity check the + # ranges to find one that seems to couple tightly with the disassembler + # + + for k in sorted(slide_buckets, key=lambda k: len(slide_buckets[k]), reverse=True): + expected = len(slide_buckets[k]) + + # + # TODO: uh, if it's getting this small, I don't feel comfortable + # selecting an ASLR slide. the user might be loading a tiny trace + # with literally 'less than 10' unique instructions (?) that + # would map to the database + # + + if expected < 10: + continue + + hit, seen = 0, 0 + for address in trace_addresses: + + # add the ASLR slide for this bucket to a traced address + rebased_address = address + k + + # the rebased address seems like it falls within the disassembler ranges + if disas_low_address <= rebased_address < disas_high_address: + seen += 1 + + # but does the address *actually* exist in the disassembler? + if rebased_address in instruction_addresses: + hit += 1 + + # + # the first *high* hit ratio is almost certainly the correct + # ASLR, practically speaking this should probably be 1.00, but + # I lowered it a bit to give a bit of flexibility. + # + # NOTE/TODO: a lower 'hit' ratio *could* occur if a lot of + # undefined instruction addresses in the disassembler get + # executed in the trace. this could be packed code / malware, + # in which case we will have to perform more aggressive analysis + # + + if (hit / seen) > 0.95: + #print(f"ASLR Slide: {k:08X} Quality: {hit/seen:0.2f} (h {hit} s {seen} e {expected})") + slide = k + break + + # + # if we do not break from the loop, we failed to find an adequate + # slide, which is very bad. + # + # NOTE/TODO: uh what do we do if we fail the ASLR slide? + # + + else: + self.slide = None + return False + + # + # TODO: err, lol this is all kind of dirty. should probably refactor + # and clean up this whole 'remapped_regions' stuff. + # + + m1 = [disas_low_address, disas_high_address] + + if slide < 0: + m2 = [m1[0] - slide, m1[1] - slide] + else: + m2 = [m1[0] + slide, m1[1] + slide] + + self.slide = slide + self._remapped_regions.append((m1, m2)) + + return True + + def _analyze_unmapped(self): + """ + Analyze trace execution to identify entry/exit to unmapped segments. + """ + if self.slide is None: + return + + # alias for readability and speed + trace, ips = self._trace, self._trace.ip_addrs + lower_mapped, upper_mapped = self._remapped_regions[0][1] + + # + # for speed, pull out the 'compressed' ip indexes that matched mapped + # (known) addresses within the disassembler context + # + + mapped_ips = set() + for i, address in enumerate(ips): + if lower_mapped <= address <= upper_mapped: + mapped_ips.add(i) + + last_good_idx = 0 + unmapped_entries = [] + + # loop through each segment in the trace + for seg in trace.segments: + seg_ips = seg.ips + seg_base = seg.base_idx + + # loop through each executed instruction in this segment + for relative_idx in range(0, seg.length): + compressed_ip = seg_ips[relative_idx] + + # the current instruction is in an unmapped region + if compressed_ip not in mapped_ips: + + # if we were in a known/mapped region previously, then save it + if last_good_idx: + unmapped_entries.append(last_good_idx) + last_good_idx = 0 + + # if we are in a good / mapped region, update our current idx + else: + last_good_idx = seg_base + relative_idx + + #print(f" - Unmapped Entry Points: {len(unmapped_entries)}") + self._unmapped_entry_points = unmapped_entries diff --git a/plugins_sogen-support/tenet/trace/arch/__init__.py b/plugins_sogen-support/tenet/trace/arch/__init__.py new file mode 100644 index 0000000..791fd3d --- /dev/null +++ b/plugins_sogen-support/tenet/trace/arch/__init__.py @@ -0,0 +1,2 @@ +from .x86 import ArchX86 +from .amd64 import ArchAMD64 \ No newline at end of file diff --git a/plugins_sogen-support/tenet/trace/arch/__pycache__/__init__.cpython-311.pyc b/plugins_sogen-support/tenet/trace/arch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..e883137 Binary files /dev/null and b/plugins_sogen-support/tenet/trace/arch/__pycache__/__init__.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/arch/__pycache__/amd64.cpython-311.pyc b/plugins_sogen-support/tenet/trace/arch/__pycache__/amd64.cpython-311.pyc new file mode 100644 index 0000000..f5fb1aa Binary files /dev/null and b/plugins_sogen-support/tenet/trace/arch/__pycache__/amd64.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/arch/__pycache__/x86.cpython-311.pyc b/plugins_sogen-support/tenet/trace/arch/__pycache__/x86.cpython-311.pyc new file mode 100644 index 0000000..8388355 Binary files /dev/null and b/plugins_sogen-support/tenet/trace/arch/__pycache__/x86.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/trace/arch/amd64.py b/plugins_sogen-support/tenet/trace/arch/amd64.py new file mode 100644 index 0000000..5b4a741 --- /dev/null +++ b/plugins_sogen-support/tenet/trace/arch/amd64.py @@ -0,0 +1,31 @@ +class ArchAMD64: + """ + AMD64 CPU Architecture Definition. + """ + MAGIC = 0x41424344 + + POINTER_SIZE = 8 + + IP = "RIP" + SP = "RSP" + + REGISTERS = \ + [ + "RAX", + "RBX", + "RCX", + "RDX", + "RBP", + "RSP", + "RSI", + "RDI", + "R8", + "R9", + "R10", + "R11", + "R12", + "R13", + "R14", + "R15", + "RIP" + ] \ No newline at end of file diff --git a/plugins_sogen-support/tenet/trace/arch/x86.py b/plugins_sogen-support/tenet/trace/arch/x86.py new file mode 100644 index 0000000..62ac474 --- /dev/null +++ b/plugins_sogen-support/tenet/trace/arch/x86.py @@ -0,0 +1,23 @@ +class ArchX86: + """ + x86 CPU Architecture Definition. + """ + MAGIC = 0x386 + + POINTER_SIZE = 4 + + IP = "EIP" + SP = "ESP" + + REGISTERS = \ + [ + "EAX", + "EBX", + "ECX", + "EDX", + "EBP", + "ESP", + "ESI", + "EDI", + "EIP" + ] \ No newline at end of file diff --git a/plugins_sogen-support/tenet/trace/file.py b/plugins_sogen-support/tenet/trace/file.py new file mode 100644 index 0000000..5f26150 --- /dev/null +++ b/plugins_sogen-support/tenet/trace/file.py @@ -0,0 +1,1792 @@ +import os +import time +import zlib +import array +import bisect +import ctypes +import struct +import zipfile +import binascii +import itertools +import collections + +#----------------------------------------------------------------------------- +# file.py -- Trace File +#----------------------------------------------------------------------------- +# +# NOTE/PREFACE: Please be aware, this is a 100% prototype implementation +# of a basic trace log file specification. It has not been designed with +# exhaustive attention to scalability + performance for use-cases that +# exceed the recommended 'maximum' of 10,000,000 (10m) instructions. +# +# There are no dependencies. There is no multiprocessing. This is will +# be a nightmare to maintain or scale further. It is 100% meant to be +# thrown away in favor of a native backend. +# +# -------------- +# +# This file contains the 'trace file' implementation for the plugin. It +# is responsible for the loading / processing of raw text traces, providing +# a few 'low level' APIs for querying information out of the lifted trace. +# +# When loading a text trace, this code will also do some basic compression +# of the trace to reduce both its on-disk and in-memory footprint. It will +# also perform some basic 'indexing' of the trace and its contents to make +# it more performant to search and query by the 'high level' trace reader. +# +# Upon completion, the indexed+compressed trace file is saved to disk +# alongside the original trace, with the '.tt' (Tenet Trace) file +# extension. This original trace can be discarded by the user. +# +# The processed trace can be loaded and used in a fraction of the time +# versus the raw text trace. The trace file implementation will also seek +# out a matching file name with the '.tt' file extension, and prioritize +# loading that over a raw text trace. +# + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +# +# attempt plugin imports, assuming this file is being run / loaded in +# the context of the plugin running within a disassembler +# + +try: + from tenet.util.log import pmsg + from tenet.util.rebase import rebase_database + from tenet.trace.arch import ArchAMD64, ArchX86 + from tenet.trace.types import TraceMemory + + # + # this script can technically be run in a standalone mone to process / digest + # a trace outside of a disassembler / the normal integration. so if the above + # fails, use the following imports to operate independently + # + +except ImportError: + from arch import ArchAMD64, ArchX86 + from .types import TraceMemory + # a little weird, but this makes 'rebase_database' a no-op outside IDA + rebase_database = lambda x: True + pmsg = print +#----------------------------------------------------------------------------- +# Definitions +#----------------------------------------------------------------------------- + +BYTE_MAX = (1 << 8) - 1 +USHRT_MAX = (1 << 16) - 1 +UINT_MAX = (1 << 32) - 1 +ULLONG_MAX = (1 << 64) - 1 + +TRACE_MEM_READ = 0 +TRACE_MEM_WRITE = 1 + +# +# NOTE: some of this stuff is probably broken / cannot be easily toggled +# anymore, so I wouldn't actually suggest playing around with them as things +# will probably break or behave erratically +# + +TRACE_STATS = False + +#DEFAULT_COMPRESSION = zipfile.ZIP_BZIP2 +#DEFAULT_COMPRESSION = zipfile.ZIP_LZMA +DEFAULT_COMPRESSION = zipfile.ZIP_DEFLATED + +#DEFAULT_SEGMENT_LENGTH = 250_000 +#DEFAULT_SEGMENT_LENGTH = 1_000_000 +DEFAULT_SEGMENT_LENGTH = USHRT_MAX +REG_OFFSET_CACHE_SIZE = 16 +REG_OFFSET_CACHE_INTERVAL = 4096 + +#----------------------------------------------------------------------------- +# Utils +#----------------------------------------------------------------------------- + +def hash_file(filepath): + """ + Return a CRC32 of the file at the given path. + """ + crc = 0 + with open(filepath, 'rb', 65536) as ins: + for x in range(int((os.stat(filepath).st_size / 65536)) + 1): + crc = zlib.crc32(ins.read(65536), crc) + return (crc & 0xFFFFFFFF) + +def number_of_bits_set(i): + """ + Count the number of bits set in the given 32bit integer. + """ + i = i - ((i >> 1) & 0x55555555) + i = (i & 0x33333333) + ((i >> 2) & 0x33333333) + return (((i + (i >> 4) & 0xF0F0F0F) * 0x1010101) & 0xffffffff) >> 24 + +def width_from_type(t): + """ + Return the byte width of a python 'struct' type definition. + """ + if t == 'B': + return 1 + elif t == 'H': + return 2 + elif t == 'I': + return 4 + elif t == 'Q': + return 8 + raise ValueError(f"Invalid type '{t}'") + +def type_from_width(width): + """ + Return an appropriate integer type for the given byte width. + """ + if width == 1: + return 'B' + elif width == 2: + return 'H' + elif width == 4: + return 'I' + elif width == 8: + return 'Q' + raise ValueError(f"Invalid type width {width}") + +def type_from_limit(limit): + """ + Return an appropriate integer type for the maximum given value. + """ + if limit <= BYTE_MAX: + return 'B' + elif limit <= USHRT_MAX: + return 'H' + elif limit <= UINT_MAX: + return 'I' + elif limit <= ULLONG_MAX: + return 'Q' + raise ValueError(f"Limit {limit:,} exceeds maximum type") + +#----------------------------------------------------------------------------- +# Serialization Structures +#----------------------------------------------------------------------------- + +class TraceInfo(ctypes.Structure): + _pack_ = 1 + _fields_ = [ + ('arch_magic', ctypes.c_uint32), + ('ip_num', ctypes.c_uint32), + ('mem_addrs_num', ctypes.c_uint32), + ('mask_num', ctypes.c_uint32), + ('mem_idx_width', ctypes.c_uint8), + ('mem_addr_width', ctypes.c_uint8), + ('original_hash', ctypes.c_uint32), + ('module_base', ctypes.c_uint64), + ] + +class SegmentInfo(ctypes.Structure): + _pack_ = 1 + _fields_ = [ + ('id', ctypes.c_uint32), + ('base_idx', ctypes.c_uint32), + ('length', ctypes.c_uint32), + + ('ip_num', ctypes.c_uint32), + ('ip_length', ctypes.c_uint32), + + ('reg_mask_num', ctypes.c_uint32), + ('reg_mask_length', ctypes.c_uint32), + ('reg_data_length', ctypes.c_uint32), + + ('mem_read_num', ctypes.c_uint32), + ('mem_read_data_length', ctypes.c_uint32), + + ('mem_write_num', ctypes.c_uint32), + ('mem_write_data_length', ctypes.c_uint32), + ] + +class MemValue(ctypes.Structure): + _pack_ = 1 + _fields_ = [ + ('mask', ctypes.c_uint8), + ('value', ctypes.c_uint8 * 8) + ] + +#----------------------------------------------------------------------------- +# Trace File +#----------------------------------------------------------------------------- + +class TraceFile(object): + """ + An interface to load and query data directly from a trace file. + """ + + def __init__(self, filepath, arch=None): + self.filepath = filepath + self.arch = arch + + # + # TODO: really, the trace file should auto-detect arch imo but i'll + # do that at a later date... + # + + if not self.arch: + self.arch = ArchAMD64() + + # a sorted array of all unique PC / IP (eg, EIP, or RIP) that appear in the trace + self.ip_addrs = None + + # + # mem_addrs: a sorted array of all unique memory addresses referenced + # in the trace (8-byte aligned) + # + # mem_masks: a sorted array of byte masks that correspond with the addrs + # array described above. each entry in this array is a single 8bit mask, + # where each bit specifies if that memory address was accessed over the + # course of the entire trace + # + # e.g: + # mem_addrs[924] = 0x401448 (an 8-byte aligned memory address) + # mem_masks[924] = 0x0F (a 'mask' of what bytes exist in the trace) + # | + # |_ a bit mask of 00001111 + # + # In this example, we know that 0x401448 --> 0x40144C were either read + # or written at some point in this trace. + # + # The alignment of pointers helps with basic id-based compression as + # these pointer id / 'mapped addresses' are used across the segments. + # + # The masks effectively create a global bitmap of all addresses that + # actually appear in the trace, allowing certain addresses to be + # immediately discarded from memory queries. This can dramatically + # reduce search complexity. + # + + self.mem_addrs = None + self.mem_masks = None + + # + # register data is stored in a contiguos blob for each trace segment. + # + # for each step / 'instruction' of the trace, we create a 32bit + # register mask that defines which registers changed. each bit in + # the mask defines 1 unique CPU register, and its position in the + # mask specifies which one it is. + # + # this will contain a list of each unique register delta mask that + # appears in the trace. instead of storing 32bit mask for each step + # of the trace, we use this table to translate a 8bit ID (an index) + # into this table of unique register masks (self.masks) + # + + self.masks = [] # TODO: rename to register_masks or something... + + # an O(1) lookup table for the 'byte size' of each register mask + self.mask_sizes = [] + + # + # a trace is broken up into segments of 64k instructions. each of + # theses segments will have small indexes / summaries embedded in + # them to make them easier to search or ignore as applicable + # + # for more information, look at the TraceSegments class + # + + self.segments = [] + + # the number of timestamps / 'instructions' for each trace segment + self.segment_length = DEFAULT_SEGMENT_LENGTH + + # the hash of the original / source log file + self.original_hash = None + + # the module base as specified in the text trace + self.module_base = 0 + + # + # now that you have some idea of how the trace file is going to be + # organized... let's actually go and try to load one + # + + self._load_trace() + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def name(self): + """ + Return the name of the trace. + """ + return os.path.basename(self.filepath) + + @property + def packed_name(self): + """ + Return the packed trace filename. + """ + root, ext = os.path.splitext(self.name) + return f"{root}.tt" + + @property + def packed_filepath(self): + """ + Return the packed trace filepath. + """ + directory = os.path.dirname(self.filepath) + return os.path.join(directory, self.packed_name) + + @property + def length(self): + """ + Return the length of the trace. (e.g, # instructions executed) + """ + if not self.segments: + return 0 + return self.segments[-1].base_idx + self.segments[-1].length + + #------------------------------------------------------------------------- + # Public + #------------------------------------------------------------------------- + + # + # I should really define this somewhere more notable... but throughout + # this project you will see the term 'idx', this is a simple abbreviation + # of 'index' that I used early on and kind of stuck with. + # + # an idx is simply an integer, that repersents a unique 'timestamp' in + # the trace file. eg, idx 0 is the start of the trace, idx 100 is + # equivalent to 100 steps into the trace, etc... + # + # an idx label attached to any sort of variable / definition in this + # codebase suggest that variable is a trace 'timestamp' / position! + # + + def get_reg_delta(self, idx): + """ + Return the register delta for a given timestamp. + """ + seg = self.get_segment(idx) + if not seg: + return {} + return seg.get_reg_delta(idx) + + def get_read_delta(self, idx): + """ + Return the memory read delta for a given timestamp. + """ + seg = self.get_segment(idx) + if not seg: + return {} + return seg.get_read_delta(idx) + + def get_write_delta(self, idx): + """ + Return the memory write delta for a given timestamp. + """ + seg = self.get_segment(idx) + if not seg: + return {} + return seg.get_write_delta(idx) + + def get_segment(self, idx): + """ + Return the trace segment for a given timestamp. + """ + for seg in self.segments: + if seg.base_idx <= idx < seg.base_idx + seg.length: + return seg + return None + + def get_reg_mask_ids_containing(self, reg_name): + """ + Return a set of reg mask ids containing the given register name. + """ + reg_id = self.arch.REGISTERS.index(reg_name.upper()) + reg_mask = 1 << reg_id + + found = set() + for i, current_mask in enumerate(self.masks): + if current_mask & reg_mask: + found.add(i) + + return found + + #------------------------------------------------------------------------- + # Save / Serialization + #------------------------------------------------------------------------- + + def _save(self): + """ + Save the packed trace to disk. + """ + + with zipfile.ZipFile(self.packed_filepath, 'w', compression=DEFAULT_COMPRESSION) as zip_archive: + self._save_header(zip_archive) + self._save_segments(zip_archive) + + self.filepath = self.packed_filepath + + def _save_header(self, zip_archive): + """ + Save the trace header to the packed trace. + """ + + # populate the trace header + header = TraceInfo() + header.arch_magic = self.arch.MAGIC + header.ip_num = len(self.ip_addrs) + header.mem_addrs_num = len(self.mem_addrs) + header.mask_num = len(self.masks) + header.mem_idx_width = width_from_type(self.mem_idx_type) + header.mem_addr_width = width_from_type(self.mem_addr_type) + header.original_hash = self.original_hash + header.module_base = self.module_base + mask_data = (ctypes.c_uint32 * len(self.masks))(*self.masks) + + # save the global trace data / header to the zip + with zip_archive.open('header', 'w') as f: + f.write(bytearray(header)) + f.write(bytearray(self.ip_addrs)) + f.write(bytearray(self.mem_addrs)) + f.write(bytearray(self.mem_masks)) + f.write(bytearray(mask_data)) + + def _save_segments(self, zip_archive): + """ + Save the trace segments to the packed trace. + """ + for segment in self.segments: + with zip_archive.open(f'segments/{segment.id}', 'w') as f: + segment.dump(f) + + #------------------------------------------------------------------------- + # Load / Deserialization + #------------------------------------------------------------------------- + + def _load_trace(self): + """ + Load a trace from disk. + + NOTE: THIS ROUTINE WILL ATTEMPT TO LOAD A PACKED TRACE INSTEAD OF A + SELECTED RAW TEXT TRACE IF IT FINDS ONE AVAILABLE!!! + """ + + # the user probably selected a '.tt' trace + if zipfile.is_zipfile(self.filepath): + self._load_packed_trace(self.filepath) + return + + # + # the user selected a '.txt' trace, but there is a '.tt' packed trace + # beside it, so let's check if the packed trace matches the text trace + # + + if zipfile.is_zipfile(self.packed_filepath): + packed_crc = self._fetch_hash(self.packed_filepath) + text_crc = hash_file(self.filepath) + + # + # the crc in the packed file seems to match the selected text log, + # so let's just load the packed trace as it should be faster + # + + if packed_crc == text_crc: + self._load_packed_trace(self.packed_filepath) + return + + # + # no luck loading / side-loading packed traces, so simply try to + # load the user selected trace as a normal text Tenet trace + # + + self._load_text_trace(self.filepath) + + def _load_packed_trace(self, filepath): + """ + Load a packed trace from disk. + """ + with zipfile.ZipFile(filepath, 'r') as zip_archive: + self._load_header(zip_archive) + self._load_segments(zip_archive) + + self.filepath = filepath + + if self.module_base: + if not rebase_database(self.module_base): + pmsg("Database rebase failed or was cancelled.") + + def _select_arch(self, magic): + """ + TODO: Select the trace CPU arch based on the given magic value. + """ + if ArchAMD64.MAGIC == magic: + self.arch = ArchAMD64() + else: + self.arch = ArchX86() + + def _fetch_hash(self, filepath): + """ + Return the original file hash (CRC32) from the given packed trace filepath. + """ + header = TraceInfo() + with zipfile.ZipFile(filepath, 'r') as zip_archive: + with zip_archive.open('header', 'r') as f: + f.readinto(header) + return header.original_hash + + def _load_header(self, zip_archive): + """ + Load the trace header from a packed trace. + """ + header = TraceInfo() + + with zip_archive.open('header', 'r') as f: + + # read the main trace info from the packed trace header + f.readinto(header) + + # select the cpu / arch for this trace + #print(f"Loading magic 0x{header.arch_magic:08X}") + self._select_arch(header.arch_magic) + + # load the (sorted) ip address table from disk + self.ip_addrs = array.array(type_from_width(self.arch.POINTER_SIZE)) + self.ip_addrs.fromfile(f, header.ip_num) + + # ('mem_idx_width', ctypes.c_uint8), + # ('mem_addr_width', ctypes.c_uint8), + self.mem_idx_type = type_from_width(header.mem_idx_width) + self.mem_addr_type = type_from_width(header.mem_addr_width) + #self.mem_mask_width = type_from_width(header.mem_mask_width) + + # load the (sorted, aligned) mem table from disk + self.mem_addrs = array.array(type_from_width(self.arch.POINTER_SIZE)) + self.mem_addrs.fromfile(f, header.mem_addrs_num) + self.mem_masks = array.array('B') + self.mem_masks.fromfile(f, header.mem_addrs_num) + + # ('mask_num', ctypes.c_uint32), + self.masks = array.array('I') + self.masks.fromfile(f, header.mask_num) + self.mask_sizes = [number_of_bits_set(mask) * self.arch.POINTER_SIZE for mask in self.masks] + + # source file hash + self.original_hash = header.original_hash + self.module_base = header.module_base + + def _load_segments(self, zip_archive): + """ + Load the trace segments from the packed trace. + """ + + for path in zip_archive.namelist(): + + # skip anything that is not a trace segment + if not (path.startswith('segments/') and path[-1] != '/'): + continue + + # load a trace segment from the packed trace file + with zip_archive.open(path, 'r') as f: + segment = TraceSegment(self) + segment.from_file(f) + + # save the segment to the trace + self.segments.append(segment) + + # sort the loaded segments by id (just in case) + self.segments.sort(key=lambda x: x.id) + + def _load_text_trace(self, filepath): + """ + Load a text trace from disk. + """ + idx = 0 + + # mappings of address/mask and their mapped (compressed) id + # - NOTE: these are only used when converting traces from text to binary + self.ip_map = collections.OrderedDict() + self.mem_map = collections.OrderedDict() + self.mask2mapped = {} + self.masks = [] + + # TODO: detect arch based on reg / lines in file + #if not self.arch: + # self._select_arch(0) + + # hash (CRC32) the source / text filepath before loading it + self.original_hash = hash_file(filepath) + + # load / parse a text trace into trace segments + with open(filepath, 'r') as f: + + # + # before processing a text trace, we will check the first line for a + # special 'mb=' (module base) tag. if this tag exists, we will use + # it to rebase the underlying database before continuing to parse + # and process the trace data + # + + # + # before processing a text trace, we will check the first line for a + # special 'mb=' (module base) tag. if this tag exists, we will use + # it to rebase the underlying database before continuing to parse + # and process the trace data + # + + first_line = f.readline() + if first_line.startswith("mb="): + try: + self.module_base = int(first_line.split("=")[1], 16) + if not rebase_database(self.module_base): + pmsg("Failed to rebase database, trace may not align correctly.") + except (ValueError, IndexError): + pmsg("Failed to parse module base from trace file header.") + # The rest of the file is the actual trace content + remaining_lines = f.readlines() + else: + # The first line was not a rebase line, so we process it with the rest. + remaining_lines = [first_line] + f.readlines() + + # + # now we process the trace lines in segments + # + + for i in range(0, len(remaining_lines), self.segment_length): + lines = remaining_lines[i:i+self.segment_length] + if not lines: + break + + segment_id = len(self.segments) + + # create a new trace segment from the given lines of text + segment = TraceSegment(self, segment_id, idx) + segment.from_lines(lines) + idx += segment.length + + # save the segment + self.segments.append(segment) + + self._finalize() + self._save() + + def get_ip(self, idx): + """ + Return the fully qualified IP for the given timestamp. + """ + seg = self.get_segment(idx) + if not seg: + raise ValueError("Invalid IDX %u" % idx) + return seg.get_ip(idx) + + def get_mapped_ip(self, ip): + """ + Return the 'mapped' (compressed) id for the given instruction address. + """ + index = bisect.bisect_left(self.ip_addrs, ip) + + try: + if ip == self.ip_addrs[index]: + return index + except IndexError: + pass + + raise ValueError(f"Address {ip:08X} does not have a mapped ID") + + # + # TODO: note, uh.. these should all be refactored... gross + # + + def get_aligned_address(self, address): + return (address >> 3) << 3 + + def get_mapped_address(self, address): + """ + Return the 'mapped' (compressed) id for the given memory address. + """ + + # + # TODO: use pointer size/alignment?? eg, this might make mem lookups faster + # if we tune it to 32bit vs 64bit (at the cost of possible trace size inflation) + # + + aligned_address = (address >> 3) << 3 + index = bisect.bisect_left(self.mem_addrs, aligned_address) + + if index == len(self.mem_addrs): + return -1 + + if aligned_address != self.mem_addrs[index]: + return -1 + + return index + + def get_aligned_address_mask(self, address, length=8): + """ + TODO: ugh hopefully we'll have a native backend before i have to try + and write a comment to describe the mess we're in + """ + mask_offset = address % 8 + aligned_address = ((address >> 3) << 3) + aligned_mask = (((1 << length) - 1) << mask_offset) & 0xFF + return aligned_mask + + def _finalize(self): + """ + Bake a parsed text trace into its final, compressed form. + """ + + if TRACE_STATS: + self._init_stats() + + # a 32 or 64 bit array.array type code, depending on the trace arch pointer size + pointer_type = type_from_width(self.arch.POINTER_SIZE) + + # bake the master ip address table + ip_map = self.ip_map + ip_addrs = sorted(list(self.ip_map.keys())) + self.ip_addrs = array.array(pointer_type, ip_addrs) + + remapped_ip = { + ip_map[address]: i for i, address in enumerate(ip_addrs) + } + + # bake the master (aligned) memory address table + mem_map = self.mem_map + mem_map_len = len(mem_map) + mem_addrs = sorted(list(mem_map.keys())) + self.mem_addrs = array.array(pointer_type, mem_addrs) + self.mem_masks = array.array('B', [0] * len(mem_addrs)) + + # generate a temporary mem re-mapping map... + remapped_mem = { + mem_map[address]: i for i, address in enumerate(mem_addrs) + } + + # pre-compute the 'size' of the data represented by a register mask + self.mask_sizes = [number_of_bits_set(mask) * self.arch.POINTER_SIZE for mask in self.masks] + + assert self.segment_length <= UINT_MAX + assert mem_map_len <= UINT_MAX + + self.mem_idx_type = type_from_limit(self.segment_length) + self.mem_addr_type = type_from_limit(mem_map_len) + + # finish packing the trace + for segment in self.segments: + segment.finalize(remapped_ip, remapped_mem) + + if TRACE_STATS: + self._harvest_stats(segment) + + # dispose of stuff we don't need anymore + del self.ip_map + del self.mem_map + del remapped_mem + + if TRACE_STATS: + self._finalize_stats() + + #------------------------------------------------------------------------- + # Trace Statistics + #------------------------------------------------------------------------- + + def _init_stats(self): + self.unique_mem_addr = set() + self.avg_unique_mem_addr = 0 + self.min_unique_mem_addr = 999999999 + self.max_unique_mem_addr = -1 + + self.avg_unique_ip = 0 + self.min_unique_ip = 999999999999 + self.max_unique_ip = -1 + + self.num_bytes_read = 0 + self.num_bytes_written = 0 + self.num_bytes_read_info = 0 + self.num_bytes_written_info = 0 + + self.num_bytes_ips = 0 + self.num_bytes_reg_data = 0 + self.num_bytes_reg_masks = 0 + + self.raw_size = 0 + + self.unique_ip = 0 + self.num_bytes_unique_ip = 0 + + self.unique_mem = 0 + self.num_bytes_unique_mem = 0 + + def _harvest_stats(self, seg): + unique_mem_addr = seg.read_addresses | seg.write_addresses + self.unique_mem_addr |= unique_mem_addr + + num_unique_mem_addr = len(unique_mem_addr) + self.avg_unique_mem_addr += num_unique_mem_addr + self.min_unique_mem_addr = min(self.min_unique_mem_addr, num_unique_mem_addr) + self.max_unique_mem_addr = max(self.max_unique_mem_addr, num_unique_mem_addr) + + num_unique_ip = seg.num_unique_ip + self.avg_unique_ip += num_unique_ip + self.min_unique_ip = min(self.min_unique_ip, num_unique_ip) + self.max_unique_ip = max(self.max_unique_ip, num_unique_ip) + + self.num_bytes_read += seg.num_bytes_read + self.num_bytes_written += seg.num_bytes_written + self.num_bytes_read_info += seg.num_bytes_read_info + self.num_bytes_written_info += seg.num_bytes_written_info + + self.num_bytes_ips += seg.num_bytes_ips + self.num_bytes_reg_data += seg.num_bytes_reg_data + self.num_bytes_reg_masks += seg.num_bytes_reg_masks + + self.raw_size += seg.raw_size_bytes + #self.length += seg.length + + def _finalize_stats(self): + self.avg_unique_ip = self.avg_unique_ip // len(self.segments) + self.avg_unique_mem_addr = self.avg_unique_mem_addr // len(self.segments) + + self.unique_ip = len(self.ip_addrs) + self.num_bytes_unique_ip = len(self.ip_addrs) * self.arch.POINTER_SIZE + self.raw_size += self.num_bytes_unique_ip + + self.unique_mem = len(self.mem_addrs) + self.num_bytes_unique_mem = len(self.mem_addrs) * self.arch.POINTER_SIZE + self.raw_size += self.num_bytes_unique_mem + + def print_stats(self): + output = [] + output.append(f"+- Trace Stats") + output.append("") + output.append(f" -- {self.length:,} timestamps") + output.append(f" -- {len(self.segments):,} segments") + output.append("") + output.append(f" - Address Stats") + output.append("") + output.append(f" -- {self.unique_ip:,} total unique ip addresses") + output.append(f" ---- {self.avg_unique_ip} avg") + output.append(f" ---- {self.min_unique_ip} min") + output.append(f" ---- {self.max_unique_ip} max") + output.append("") + output.append(f" -- {len(self.unique_mem_addr):,} total unique mem addresses") + output.append(f" ---- {self.avg_unique_mem_addr} avg") + output.append(f" ---- {self.min_unique_mem_addr} min") + output.append(f" ---- {self.max_unique_mem_addr} max") + output.append("") + output.append(f" - Memory / Disk Footprint") + output.append("") + output.append(f" -- {self.raw_size/(1024*1024):0.2f}mb - raw size") + output.append("") + output.append(f" ---- {self.num_bytes_unique_ip / (1024*1024):0.2f}mb ({(self.num_bytes_unique_ip / self.raw_size) * 100 :3.2f}%) - ip addrs") + output.append(f" ---- {self.num_bytes_unique_mem / (1024*1024):0.2f}mb ({(self.num_bytes_unique_mem / self.raw_size) * 100 :3.2f}%) - mem addrs") + output.append(f" ---- {self.num_bytes_ips / (1024*1024):0.2f}mb ({(self.num_bytes_ips / self.raw_size) * 100 :3.2f}%) - ip trace") + output.append(f" ---- {self.num_bytes_reg_data / (1024*1024):0.2f}mb ({(self.num_bytes_reg_data / self.raw_size) * 100 :3.2f}%) - reg data") + output.append(f" ---- {self.num_bytes_reg_masks / (1024*1024):0.2f}mb ({(self.num_bytes_reg_masks / self.raw_size) * 100 :3.2f}%) - reg masks") + output.append("") + output.append(f" ---- {self.num_bytes_read / (1024*1024):0.2f}mb ({(self.num_bytes_read / self.raw_size) * 100 :3.2f}%) - bytes read") + output.append(f" ---- {self.num_bytes_written / (1024*1024):0.2f}mb ({(self.num_bytes_written / self.raw_size) * 100 :3.2f}%) - bytes written") + output.append(f" ---- {self.num_bytes_read_info / (1024*1024):0.2f}mb ({(self.num_bytes_read_info / self.raw_size) * 100 :3.2f}%) - read pointers") + output.append(f" ---- {self.num_bytes_written_info / (1024*1024):0.2f}mb ({(self.num_bytes_written_info / self.raw_size) * 100 :3.2f}%) - write pointers") + print(''.join(output)) + +class TraceSegment(object): + """ + A segment of trace data. + """ + + def __init__(self, trace, id=0, base_idx=0): + self.id = id + self.arch = trace.arch + self.trace = trace + + self.base_idx = base_idx + self.length = 0 + + self.reg_data = None + self.reg_masks = None + + self.read_data = None + self.read_idxs = None + self.read_addrs = None + self.read_masks = None + self.read_offsets = [] + + self.write_data = None + self.write_idxs = None + self.write_addrs = None + self.write_masks = None + self.write_offsets = [] + + self.mem_delta = collections.defaultdict(MemValue) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def read_set(self): + return set(self.read_addrs) + + @property + def write_set(self): + return set(self.write_addrs) + + @property + def num_unique_ip(self): + return len(set(self.ips)) + + @property + def num_unique_mem_addresses(self): + return len(self.read_set | self.write_set) + + @property + def num_bytes_read(self): + return len(self.read_data) + + @property + def num_bytes_written(self): + return len(self.write_data) + + #@property + #def num_bytes_read_info(self): + # return ctypes.sizeof(self._mem_read_info) + + #@property + #def num_bytes_written_info(self): + # return ctypes.sizeof(self._mem_write_info) + + @property + def num_bytes_reg_data(self): + return len(self.reg_data) + + @property + def num_bytes_ips(self): + return ctypes.sizeof(self.ips) + + @property + def num_bytes_reg_masks(self): + return ctypes.sizeof(self.reg_masks) + + @property + def raw_size_bytes(self): + size = 0 + + # reg data storage costs + size += self.num_bytes_ips + size += self.num_bytes_reg_data + size += self.num_bytes_reg_masks + + # memory data storage costs + size += self.num_bytes_read + size += self.num_bytes_written + size += self.num_bytes_read_info + size += self.num_bytes_written_info + + return size + + @property + def raw_size_mb(self): + return self.raw_size_bytes / (1024*1024) + + def __str__(self): + output = [] + output.append(f"Trace Segment -- IDX {self.base_idx}") + output.append(f" -- Reg Data {len(self.reg_data)} bytes ({len(self.reg_data) / (1024*1024):0.2f}mb)") + output.append(f" -- Unique IP {len(set(self.ips))}") + output.append(f" -- Raw Size {self.raw_size_mb:0.2f}mb") + return ''.join(output) + + #------------------------------------------------------------------------- + # Public + #------------------------------------------------------------------------- + + def from_lines(self, lines): + """ + Load a trace segment from the given lines. + """ + + # ip storage + self.ips = [0 for x in range(self.trace.segment_length)] + + # register storage (minus IP) + MAX_REG_DATA = self.trace.arch.POINTER_SIZE * len(self.trace.arch.REGISTERS) * self.trace.segment_length + self.reg_data = bytearray(MAX_REG_DATA) + self.reg_offsets = array.array("I", [0] * REG_OFFSET_CACHE_SIZE) + self.reg_masks = [0 for x in range(self.trace.segment_length)] + self._reg_offset = 0 + + # memory defs + self._mem_read_info = [] + self.read_data = bytearray() + self._mem_write_info = [] + self.write_data = bytearray() + self._max_read_size = 0 + self._max_write_size = 0 + + self._process_lines(lines) + #print(f"Snapshot entries: {len(self.mem_delta)}") + + def from_file(self, f): + """ + Load the trace segment from the given filestream. + """ + self.load(f) + + def get_ip(self, idx): + """ + Return the IP for the given timestamp. + """ + relative_idx = idx - self.base_idx + return self.trace.ip_addrs[self.ips[relative_idx]] + + def get_reg_delta(self, idx): + """ + Return the register delta for the given timestamp. + """ + relative_idx = idx - self.base_idx + + # IP is the only register guaranteed to have changed each step + ip_address = self.trace.ip_addrs[self.ips[relative_idx]] + + # fetch the mask that tells which registers have changed this delta + mask = self.trace.masks[self.reg_masks[relative_idx]] + + # if no registers changed, nothing to do but return IP + if not mask: + return {self.trace.arch.IP: ip_address} + + # + # fetch the closest cached register data offset that we can start from + # for computing precisely where we should be working backwards from + # + + cache_index = int(relative_idx / REG_OFFSET_CACHE_INTERVAL) + cache_offset = self.reg_offsets[cache_index] + cache_idx = cache_index * REG_OFFSET_CACHE_INTERVAL + + # compute the current 'offset' in the reg data that we will work back from + sizes = self.trace.mask_sizes + offset_masks = self.reg_masks[cache_idx:relative_idx][::-1] + offset = cache_offset + sum([sizes[mask_id] for mask_id in offset_masks]) + + # compute the location of the packed register delta data + #offset_slow = sum([sizes[mask_id] for mask_id in self.reg_masks[:relative_idx]]) + #assert offset == offset_slow + + # fetch the register data + reg_names = self._mask2regs(mask) + num_regs = len(reg_names) + reg_data = self.reg_data[offset:offset + (num_regs * self.arch.POINTER_SIZE)] + + # unpack the register data + pack_fmt = 'Q' if self.arch.POINTER_SIZE == 8 else 'I' + reg_values = struct.unpack(pack_fmt * num_regs, reg_data) + + # pack all the registers into a dict that will be returned to the user + registers = dict(zip(reg_names, reg_values)) + registers[self.trace.arch.IP] = ip_address + + # return the completed register delta + return registers + + # + # TODO: ugh some of this stuff is pretty gross too, is it even used still...? + # + + def get_read_delta(self, idx): + """ + Return the memory read delta for the given timestamp. + """ + return self._get_mem_delta(idx, TRACE_MEM_READ) + + def get_write_delta(self, idx): + """ + Return the memory write delta for the given timestamp. + """ + return self._get_mem_delta(idx, TRACE_MEM_WRITE) + + def _get_mem_delta(self, idx, mem_type): + """ + Internal abstraction to search memory delta lists. + """ + relative_idx = idx - self.base_idx + found, offset = [], 0 + + if mem_type == TRACE_MEM_WRITE: + idxs, addrs, masks, offsets, data = self.write_idxs, self.write_addrs, self.write_masks, self.write_offsets, self.write_data + else: + idxs, addrs, masks, offsets, data = self.read_idxs, self.read_addrs, self.read_masks, self.read_offsets, self.read_data + + try: + i = idxs.index(relative_idx) + except ValueError: + return [] + + while i < len(idxs) and idxs[i] == relative_idx: + + # + # fetch the 'aligned' address for this memory access, and the + # mask which specifes which bytes were touched starting from + # the aligned address + # + + aligned_address = self.trace.mem_addrs[addrs[i]] + access_mask = masks[i] + + # extract the raw data for this memory access + offset = offsets[i] + length = number_of_bits_set(masks[i]) + raw_data = data[offset:offset+length] + + address = aligned_address + seen_byte = False # TODO KLUDGE + while access_mask: + if access_mask & 1 == 0: + address += 1 + assert not seen_byte, "gap in memory access?" + else: + seen_byte = True + access_mask >>= 1 + + found.append((address, raw_data)) + i += 1 + + # return all the hits + return found + + def get_reg_info(self, idx, reg_names): + """ + Given a starting timestamp and a list of register names, return + + { reg_name: (value, idx) } + + ... for each discoverable register in this segment. + + """ + relative_idx = idx - self.base_idx + start_idx = relative_idx + 1 + if not (0 <= relative_idx < self.length): + return {} + + # compute a 32bit mask of the registers we need to find + target_mask = self._regs2mask(reg_names) + + # + # fetch the closest cached register data offset that we can start from + # for computing precisely where we should be working backwards from + # + + cache_index = int(start_idx / REG_OFFSET_CACHE_INTERVAL) + cache_offset = self.reg_offsets[cache_index] + cache_idx = cache_index * REG_OFFSET_CACHE_INTERVAL + + # alias for faster access / readability + sizes = self.trace.mask_sizes + masks = self.trace.masks + + # compute the current 'offset' in the reg data that we will work back from + offset_masks = self.reg_masks[cache_idx:start_idx][::-1] + offset = cache_offset + sum([sizes[mask_id] for mask_id in offset_masks]) + + # the map of reg_name --> (reg_value, src_idx) to return + found_registers = {} + + # loop backwards through the segment, starting from the given idx + search_masks = self.reg_masks[:start_idx][::-1] + #offset_slow = sum([sizes[mask_id] for mask_id in search_masks]) + #assert offset == offset_slow + for i, mask_id in enumerate(search_masks): + + # translate the mask id for this step into its register bitfield + current_mask = masks[mask_id] + + # + # since we are iterating backwards through the register data, we + # need to subtract from the offset immediately as it is pointing + # at the end of the register data for this mask. + # + + offset -= sizes[mask_id] + + # ignore masks that do not touch the target registers + if not current_mask & target_mask: + continue + + # translate the 32bit reg mask into a list of register names + found_mask = current_mask & target_mask + found_names = self._mask2regs(found_mask) + + # fetch the registers for this delta / timestamp + registers = self._unpack_registers(current_mask, offset) + + # add the found register names and the current (global) idx + for reg_name in found_names: + found_registers[reg_name] = (registers[reg_name], (self.base_idx + (start_idx - i))) + + # remove the registers we found from the remaining search space + target_mask ^= found_mask + + # if target_mask is 0, then there are no more registers to look for + if not target_mask: + break + + return found_registers + + def get_mem_data(self, mem_id, set_id, data_mask): + """ + Return the data for a given mem access id, in the given set. + """ + + if set_id == 1: + addrs, masks, offsets, data = self.write_addrs, self.write_masks, self.write_offsets, self.write_data + else: + addrs, masks, offsets, data = self.read_addrs, self.read_masks, self.read_offsets, self.read_data + + offset = offsets[mem_id] #sum([number_of_bits_set(mask) for mask in masks[:mem_id]]) + #offset = sum([number_of_bits_set(mask) for mask in masks[:mem_id]]) + length = number_of_bits_set(masks[mem_id]) + raw_data = data[offset:offset+length] + + address = self.trace.mem_addrs[addrs[mem_id]] + output = TraceMemory(address, 8) + + byte, i = 0, 0 + + while data_mask: + if data_mask & 1: + output.data[i] = raw_data[byte] + output.mask[i] = 0xFF + byte += 1 + i += 1 + data_mask >>= 1 + + #assert byte == length + + return output + + #------------------------------------------------------------------------- + # Finalization + #------------------------------------------------------------------------- + + def load(self, f): + """ + Load the trace segment from the given filestream. + """ + info = SegmentInfo() + f.readinto(info) + + self.id = info.id + self.base_idx = info.base_idx + self.length = info.length + + if info.ip_num == 0: + raise ValueError("Empty trace file (ip_num == 0)") + + ip_itemsize = info.ip_length // info.ip_num + ip_type = type_from_width(ip_itemsize) + + # load the ip trace + self.ips = array.array(ip_type) + self.ips.fromfile(f, info.ip_num) + + # load the reg mask data + reg_mask_type = type_from_width(info.reg_mask_length // info.reg_mask_num) + self.reg_masks = array.array(reg_mask_type) + self.reg_masks.fromfile(f, info.reg_mask_num) + + # load the reg data + self.reg_data = bytearray(info.reg_data_length) + f.readinto(self.reg_data) + + # load the pre-computed register offsets + self.reg_offsets = array.array("I") + self.reg_offsets.fromfile(f, REG_OFFSET_CACHE_SIZE) + + # + # memory + # + + idx_type = self.trace.mem_idx_type + addr_type = self.trace.mem_addr_type + + # load the memory read metadata + self.read_idxs = array.array(idx_type) + self.read_idxs.fromfile(f, info.mem_read_num) + self.read_addrs = array.array(addr_type) + self.read_addrs.fromfile(f, info.mem_read_num) + self.read_masks = array.array('B') + self.read_masks.fromfile(f, info.mem_read_num) + + # load the raw memory read data + self.read_data = bytearray(info.mem_read_data_length) + f.readinto(self.read_data) + + # load the memory write metadata + self.write_idxs = array.array(idx_type) + self.write_idxs.fromfile(f, info.mem_write_num) + self.write_addrs = array.array(addr_type) + self.write_addrs.fromfile(f, info.mem_write_num) + self.write_masks = array.array('B') + self.write_masks.fromfile(f, info.mem_write_num) + + # load the raw memory write data + self.write_data = bytearray(info.mem_write_data_length) + f.readinto(self.write_data) + + # load the mem delta / 'snapshot' data + addr_set = sorted(set(self.read_addrs + self.write_addrs)) + delta_entries = (MemValue * len(addr_set))() + f.readinto(delta_entries) + + self.mem_delta = dict(zip(addr_set, delta_entries)) + + self._compute_mem_offsets() + + def dump(self, f): + """ + Dump the trace segment to the given filestream. + """ + info = SegmentInfo() + + info.id = self.id + info.base_idx = self.base_idx + info.length = self.length + + info.ip_num = self.length + info.ip_length = info.ip_num * self.ips.itemsize + + info.reg_mask_num = len(self.reg_masks) + info.reg_mask_length = info.reg_mask_num * self.reg_masks.itemsize + info.reg_data_length = len(self.reg_data) # bytearray + + info.mem_read_num = len(self.read_idxs) + info.mem_read_data_length = len(self.read_data) + + info.mem_write_num = len(self.write_idxs) + info.mem_write_data_length = len(self.write_data) + + f.write(bytearray(info)) + f.write(bytearray(self.ips)) + + f.write(bytearray(self.reg_masks)) + f.write(self.reg_data) + f.write(bytearray(self.reg_offsets)) + + self.read_idxs.tofile(f) + self.read_addrs.tofile(f) + self.read_masks.tofile(f) + f.write(self.read_data) + + self.write_idxs.tofile(f) + self.write_addrs.tofile(f) + self.write_masks.tofile(f) + f.write(self.write_data) + + for mapped_address in sorted(set(self.read_addrs + self.write_addrs)): + f.write(bytearray(self.mem_delta[mapped_address])) + + #------------------------------------------------------------------------- + # Finalization + #------------------------------------------------------------------------- + + def finalize(self, remapped_ip, remapped_mem): + """ + Bake the trace segment into its final, packed form. + """ + self._finalize_registers(remapped_ip) + self._finalize_memory(remapped_mem) + + def _finalize_registers(self, remapped_ip): + """ + Bake registers into ctype structures. + """ + assert len(remapped_ip) <= UINT_MAX + assert len(self.trace.mask2mapped) <= USHRT_MAX + + # + # pack IP trace + # + + ip_type = type_from_limit(len(remapped_ip)) + new_ips = array.array(ip_type, [0] * len(self.ips)) + + for i, mapped_ip in enumerate(self.ips): + new_ips[i] = remapped_ip[mapped_ip] + + del self.ips + self.ips = new_ips + + # + # pack register masks + # + + mask_type = type_from_limit(len(self.trace.mask2mapped)) + new_masks = array.array(mask_type, self.reg_masks) + + del self.reg_masks + self.reg_masks = new_masks + + def _finalize_memory(self, remapped_mem): + """ + Bake memory into ctype structures. + """ + idx_type = self.trace.mem_idx_type + addr_type = self.trace.mem_addr_type + + # + # pack read data + # + + # allocate fast, compact python arrays to hold our mem read info + read_idxs = array.array(idx_type) + read_addrs = array.array(addr_type) + read_masks = array.array('B') + + # transfer read metadata into compact / searchable arrays + for entry in self._mem_read_info: + idx, old_mapped_address, mask = entry + + # convert the old mapped address to a new mapped address + mapped_address = remapped_mem[old_mapped_address] + + # pack the data into fast / compact python arrays + read_idxs.append(idx) + read_addrs.append(mapped_address) + read_masks.append(mask) + + del self._mem_read_info + self.read_idxs = read_idxs + self.read_addrs = read_addrs + self.read_masks = read_masks + + # + # pack write data + # + + # allocate fast, compact python arrays to hold our mem write info + write_idxs = array.array(idx_type) + write_addrs = array.array(addr_type) + write_masks = array.array('B') + + # transfer write metadata into compact / searchable arrays + for entry in self._mem_write_info: + idx, old_mapped_address, mask = entry + + # convert the old mapped address to a new mapped address + mapped_address = remapped_mem[old_mapped_address] + + # pack the data into fast / compact python arrays + write_idxs.append(idx) + write_addrs.append(mapped_address) + write_masks.append(mask) + + del self._mem_write_info + self.write_idxs = write_idxs + self.write_addrs = write_addrs + self.write_masks = write_masks + + # + # build trace mask + # + + new_delta = {} + mem_masks = self.trace.mem_masks + + for old_mapped_address, mv in self.mem_delta.items(): + mapped_address = remapped_mem[old_mapped_address] + new_delta[mapped_address] = mv + mem_masks[mapped_address] |= mv.mask + + del self.mem_delta + self.mem_delta = new_delta + + self._compute_mem_offsets() + + def _compute_mem_offsets(self): + """ + Pre-compute the offset of each memory access into the raw memory blobs. + """ + temp_sizes = {} + + self.read_offsets = array.array('I', [0] * len(self.read_masks)) + self.write_offsets = array.array('I', [0] * len(self.write_masks)) + + mem_sets = [ + (self.read_offsets, self.read_masks), + (self.write_offsets, self.write_masks) + ] + + for offsets, masks in mem_sets: + offset = 0 + for i, mask in enumerate(masks): + offsets[i] = offset + length = temp_sizes.setdefault(mask, number_of_bits_set(mask)) + offset += length + + #------------------------------------------------------------------------- + # Processing / Logic + #------------------------------------------------------------------------- + + def _process_lines(self, lines): + """ + Process text lines from a delta reg/mem trace. + """ + IP = self.trace.arch.IP + REGISTERS = self.trace.arch.REGISTERS + + relative_idx = 0 + + try: + + for line in lines: + if not self._process_line(line, relative_idx): + continue + relative_idx += 1 + + # TODO: pretty gross, but let's just wrap it to make these issues more apparents + except Exception as e: + pmsg(f"LINE PARSE FAILED, line ~{self.base_idx+relative_idx:,}, contents '{line}'") + pmsg(str(e)) + + self.reg_data = bytearray(self.reg_data[:self._reg_offset]) + self.ips = self.ips[:relative_idx] + self.length = relative_idx + + def _process_line(self, line, relative_idx): + """ + Process one line of text from a delta reg/mem trace. + """ + IP = self.trace.arch.IP + REGISTERS = self.trace.arch.REGISTERS + + delta = line.split(",") + registers = {} + + # split the state info (registers, memory) into individual items to process + for item in delta: + name, value = item.split("=") + name = name.upper() + + # special compression of IP + if name == IP: + ip = int(value, 16) + + try: + mapped_ip = self.trace.ip_map[ip] + + except KeyError: + mapped_ip = len(self.trace.ip_map) + self.trace.ip_map[ip] = mapped_ip + + self.ips[relative_idx] = mapped_ip + + # GPR + elif name in REGISTERS: + registers[name] = int(value, 16) + + # handle memory r/w/rw access + elif name in ["MR", "MW", "MRW"]: + + # + # a single line can contain multiple memory entries of the same + # type. they will be delimited by a ';' + # + # eg: mr=ADDRESS:DATA;ADDRESS:DATA;... + # + + for entry in value.split(';'): + address, hex_data = entry.split(":") + address = int(address, 16) + hex_data = bytes(hex_data.strip(), 'utf-8') + data = binascii.unhexlify(hex_data) + self._process_mem_entry(address, data, name, relative_idx) + + else: + raise ValueError(f"Invalid line in text trace! '{line}' error on '{name}', (value '{value}')") + + self._pack_registers(registers, relative_idx) + + return True + + def _process_mem_entry(self, address, data, access_type, relative_idx): + """ + TODO + """ + + byte = 0 + for mapped_address, access_mask, access_data in self._map_mem_access(address, data): + + # read + if access_type == 'MR': + + self._mem_read_info.append((relative_idx, mapped_address, access_mask)) + self.read_data += access_data + #self._max_read_size = max(self._max_read_size, data_len) + + # write + elif access_type == 'MW': + self._mem_write_info.append((relative_idx, mapped_address, access_mask)) + self.write_data += access_data + #print(self._mem_write_info[-1], hexdump(data), "REAL OFFSET", len(self.write_data)-len(data)) + #self._max_write_size = max(self._max_write_size, data_len) + + # read AND write (eg, inc [rax]) + elif access_type == 'MRW': + + # read + self._mem_read_info.append((relative_idx, mapped_address, access_mask)) + self.read_data += access_data + #self._max_read_size = max(self._max_read_size, data_len) + + # write + self._mem_write_info.append((relative_idx, mapped_address, access_mask)) + self.write_data += access_data + #self._max_write_size = max(self._max_write_size, data_len) + + else: + raise ValueError("Unknown field in trace: '%s=...'" % access_type) + + mv = self.mem_delta[mapped_address] + mv.mask |= access_mask + #print(f"ADDRESS: 0x{address:08X} MASK: {access_mask:02X}") + + # snapshot stuff + bit, byte = 0, 0 + while access_mask: + if access_mask & 1: + #print(bit, byte) + mv.value[bit] = access_data[byte] + #byte_shift = (bit * 8) + #byte_mask = 0xFF << byte_shift + #value[0] = (value[0] & ~byte_mask) | (data[byte] << byte_shift) + byte += 1 + access_mask >>= 1 + bit += 1 + + def _map_mem_access(self, address, data): + """ + TODO: lol welcome to hell :^) + """ + output = [] + data_len = len(data) + access_data = data + + mask_offset = address % 8 + remaining_mask = ((1 << data_len) - 1) << mask_offset + aligned_address = ((address >> 3) << 3) + access_length = min(len(access_data), (8 - mask_offset)) + + while remaining_mask: + + aligned_mask = remaining_mask & 0xFF + + mapped_address = self.trace.mem_map.setdefault(aligned_address, len(self.trace.mem_map)) + + output.append((mapped_address, aligned_mask, access_data[:access_length])) + access_data = access_data[access_length:] + + remaining_mask >>= 8 + aligned_address += 8 + access_length = min(len(access_data), 8) + + return output + + def _pack_registers(self, registers, relative_idx): + """ + Compress a register delta. + """ + num_regs = len(registers) + + # + # to help improve the speed of looking up register values in the data + # blob, we cache pre-computed offsets at finxed intervals throughout + # the segment. + # + # at query time, we can pick the closest cached interval prior to the + # target idx and only re-compute a fraction of the offsets needed to + # find the correct offset into the data blob to fetch our reg delta + # + + if not(relative_idx % REG_OFFSET_CACHE_INTERVAL): + cache_index = int(relative_idx / REG_OFFSET_CACHE_INTERVAL) + #print(f"rIDX: {relative_idx:,} CACHE: {cache_index} LEN: {len(self.reg_offsets)}") + self.reg_offsets[cache_index] = self._reg_offset + + # + # XXX/TODO: BODGE FOR WHEN PEOPLE DON'T DUMP A FULL REGISTER STATE + # + + if self.base_idx == 0 and self._reg_offset == 0: + if num_regs != len(self.arch.REGISTERS): + for reg_name in self.arch.REGISTERS: + if reg_name not in registers: + if reg_name == self.arch.IP: + continue + pmsg(f"MISSING INITIAL REGISTER VALUE FOR {reg_name}") + registers[reg_name] = 0 + num_regs += 1 + + mask = self._regs2mask(registers.keys()) + + try: + mapped_mask = self.trace.mask2mapped[mask] + except KeyError: + mapped_mask = len(self.trace.mask2mapped) + self.trace.mask2mapped[mask] = mapped_mask + self.trace.masks.append(mask) + + self.reg_masks[relative_idx] = mapped_mask + + value_pairs = sorted([(self.arch.REGISTERS.index(name), value) for name, value in registers.items()]) + values = [x[1] for x in value_pairs] + pack_fmt = 'Q' if self.arch.POINTER_SIZE == 8 else 'I' + struct.pack_into(pack_fmt * num_regs, self.reg_data, self._reg_offset, *values) + self._reg_offset += num_regs * self.arch.POINTER_SIZE + + def _unpack_registers(self, mask, offset): + """ + Unpack register data from the register buffer. + """ + reg_names = self._mask2regs(mask) + + # fetch the register data + num_regs = len(reg_names) + reg_data = self.reg_data[offset:offset + (num_regs * self.arch.POINTER_SIZE)] + + # unpack the register data + pack_fmt = 'Q' if self.arch.POINTER_SIZE == 8 else 'I' + reg_values = struct.unpack(pack_fmt * num_regs, reg_data) + + # pack all the registers into a dict that will be returned to the user + registers = dict(zip(reg_names, reg_values)) + + # return the completed register delta + return registers + + #------------------------------------------------------------------------- + # Util + #------------------------------------------------------------------------- + + def _regs2mask(self, regs): + """ + Convert a list of register names to a register mask. + """ + mask = 0 + for reg in regs: + reg_bit_index = self.arch.REGISTERS.index(reg) + mask |= 1 << reg_bit_index + return mask + + def _mask2regs(self, mask): + """ + Convert a register mask to a list of register names. + """ + regs, bit_index = [], 0 + while mask: + if mask & 1: + regs.append(self.arch.REGISTERS[bit_index]) + mask >>= 1 + bit_index += 1 + return regs diff --git a/plugins_sogen-support/tenet/trace/reader.py b/plugins_sogen-support/tenet/trace/reader.py new file mode 100644 index 0000000..90ef9d6 --- /dev/null +++ b/plugins_sogen-support/tenet/trace/reader.py @@ -0,0 +1,1947 @@ +import bisect +import struct +import logging + +from tenet.types import BreakpointType +from tenet.util.log import pmsg +from tenet.util.misc import register_callback, notify_callback +from tenet.trace.file import TraceFile +from tenet.trace.types import TraceMemory +from tenet.trace.analysis import TraceAnalysis + +logger = logging.getLogger("Tenet.Trace.Reader") + +#----------------------------------------------------------------------------- +# reader.py -- Trace Reader +#----------------------------------------------------------------------------- +# +# NOTE/PREFACE: If you have not already, please read through the overview +# comment at the start of the TraceFile (file.py) code. This file (the +# Trace Reader) builds directly ontop of trace files. +# +# -------------- +# +# This file contains the 'trace reader' implementation for the plugin. It +# is responsible for the navigating a loaded trace file, providing 'high +# level' APIs one might expect to 'efficiently' query a program for +# registers or memory at any timestamp of execution. +# +# Please be mindful that like the TraceFile implementation, TraceReader +# should be re-written entirely in a native language. Under the hood, it's +# not exactly pretty. It was written to make the plugin simple to install +# and experience as a prototype. It is not equipped to adequately scale to +# real world targets. +# +# The most important takeaway from this file should be interface / API +# that it exposes to the plugin. A performant, native TraceReader that +# exposes the same API would be enough to scale the plugin's ability to +# navigate traces that span tens of billions (... maybe even hundreds of +# billions) of instructions. +# + +class TraceDelta(object): + """ + Trace Delta + """ + + def __init__(self, registers, mem_read, mem_write): + self.registers = registers + self.mem_reads = mem_read + self.mem_writes = mem_write + +class TraceReader(object): + """ + A high level, debugger-like interface for querying Tenet traces. + """ + + def __init__(self, filepath, architecture, dctx=None): + self.idx = 0 + self.dctx = dctx + self.arch = architecture + + # load the given trace file from disk + self.trace = TraceFile(filepath, architecture) + self.analysis = TraceAnalysis(self.trace, dctx) + + self._idx_cached_registers = -1 + self._cached_registers = {} + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + self._idx_changed_callbacks = [] + + #------------------------------------------------------------------------- + # Trace Properties + #------------------------------------------------------------------------- + + @property + def ip(self): + """ + Return the current instruction pointer. + """ + return self.get_register(self.arch.IP) + + @property + def rebased_ip(self): + """ + Return a rebased version of the current instruction pointer (if available). + """ + return self.analysis.rebase_pointer(self.ip) + + @property + def sp(self): + """ + Return the current stack pointer. + """ + return self.get_register(self.arch.SP) + + @property + def registers(self): + """ + Return the current registers. + """ + return self.get_registers() + + @property + def segment(self): + """ + Return the current trace segment. + """ + return self.trace.get_segment(self.idx) + + @property + def delta(self): + """ + Return the state delta since the previous timestamp. + """ + read_set, write_set = set(), set() + + for address, data in self.trace.get_read_delta(self.idx): + read_set |= {address + i for i in range(len(data))} + + for address, data in self.trace.get_write_delta(self.idx): + write_set |= {address + i for i in range(len(data))} + + regs = self.trace.get_reg_delta(self.idx) + + return TraceDelta(regs, read_set, write_set) + + #------------------------------------------------------------------------- + # Trace Navigation + #------------------------------------------------------------------------- + + def seek(self, idx): + """ + Seek the trace to the given timestamp. + """ + + # clamp the index if it goes past the end of the trace + if idx >= self.trace.length: + idx = self.trace.length - 1 + elif idx < 0: + idx = 0 + + # save the new position + self.idx = idx + self.get_registers() + self._notify_idx_changed() + + def seek_percent(self, percent): + """ + Seek to an approximate percentage into the trace. + """ + target_idx = int(self.trace.length * (percent / 100)) + self.seek(target_idx) + + def seek_to_first(self, address, access_type, length=1): + """ + Seek to the first instance of the given breakpoint. + + Returns True on success, False otherwise. + """ + return self.seek_to_next(address, access_type, length, 0) + + def seek_to_final(self, address, access_type, length=1): + """ + Seek to the final instance of the given breakpoint. + + Returns True on success, False otherwise. + """ + return self.seek_to_prev(address, access_type, length, self.trace.length-1) + + def seek_to_next(self, address, access_type, length=1, start_idx=None): + """ + Seek to the next instance of the given breakpoint. + + Returns True on success, False otherwise. + """ + if start_idx is None: + start_idx = self.idx + 1 + + if access_type == BreakpointType.EXEC: + + assert length == 1 + idx = self.find_next_execution(address, start_idx) + + elif access_type == BreakpointType.READ: + + if length == 1: + idx = self.find_next_read(address, start_idx) + else: + idx = self.find_next_region_read(address, length, start_idx) + + elif access_type == BreakpointType.WRITE: + + if length == 1: + idx = self.find_next_write(address, start_idx) + else: + idx = self.find_next_region_write(address, length, start_idx) + + elif access_type == BreakpointType.ACCESS: + + if length == 1: + idx = self.find_next_access(address, start_idx) + else: + idx = self.find_next_region_access(address, length, start_idx) + + else: + raise NotImplementedError + + if idx == -1: + return False + + self.seek(idx) + return True + + def seek_to_prev(self, address, access_type, length=1, start_idx=None): + """ + Seek to the previous instance of the given breakpoint. + + Returns True on success, False otherwise. + """ + if start_idx is None: + start_idx = self.idx - 1 + + if access_type == BreakpointType.EXEC: + + assert length == 1 + idx = self.find_prev_execution(address, start_idx) + + elif access_type == BreakpointType.READ: + + if length == 1: + idx = self.find_prev_read(address, start_idx) + else: + idx = self.find_prev_region_read(address, length, start_idx) + + elif access_type == BreakpointType.WRITE: + + if length == 1: + idx = self.find_prev_write(address, start_idx) + else: + idx = self.find_prev_region_write(address, length, start_idx) + + elif access_type == BreakpointType.ACCESS: + + if length == 1: + idx = self.find_prev_access(address, start_idx) + else: + idx = self.find_prev_region_access(address, length, start_idx) + + else: + raise NotImplementedError + + if idx == -1: + return False + + self.seek(idx) + return True + + def step_forward(self, n=1, step_over=False): + """ + Step the trace forward by n steps. + + If step_over=True, and a disassembler context is available to the + trace reader, it will attempt to step over calls while stepping. + """ + if not step_over: + self.seek(self.idx + n) + else: + self._step_over_forward(n) + + def step_backward(self, n=1, step_over=False): + """ + Step the trace backwards. + + If step_over=True, and a disassembler context is available to the + trace reader, it will attempt to step over calls while stepping. + """ + if not step_over: + self.seek(self.idx - n) + else: + self._step_over_backward(n) + + def _step_over_forward(self, n): + """ + Step the trace forward over n instructions / calls. + """ + address = self.get_ip(self.idx) + bin_address = self.analysis.rebase_pointer(address) + + # + # get the address for the linear instruction address after the + # current instruction + # + + bin_next_address = self.dctx.get_next_insn(bin_address) + if bin_next_address == -1: + self.seek(self.idx + 1) + return + + trace_next_address = self.analysis.rebase_pointer(bin_next_address) + + # + # find the next time the instruction after this instruction is + # executed in the trace + # + + next_idx = self.find_next_execution(trace_next_address, self.idx) + + # + # the instruction after the call does not appear in the trace, + # so just fall-back to 'step into' behavior + # + + if next_idx == -1: + self.seek(self.idx + 1) + return + + self.seek(next_idx) + + def _step_over_backward(self, n): + """ + Step the trace backward over n instructions / calls. + """ + address = self.get_ip(self.idx) + bin_address = self.analysis.rebase_pointer(address) + + bin_prev_address = self.dctx.get_prev_insn(bin_address) + + # + # could not get the address of the instruction prior to the current + # one which means we will not be able to decode it / and really are + # not sure what/where the user would be stepping backwards to... + # + # TODO: it's possible to handle this case, but requires a more + # performant backend than the python prototype that powers this + # + + if bin_prev_address == -1: + self.seek(self.idx - 1) + return + + # + # special handling for when the prior instruction appears to be a call + # instruction, this is perhaps the most important 'step over' scenario + # and also pretty tricky to handle... + # + + if self.dctx.is_call_insn(bin_prev_address): + + # get the previous stack pointer address + sp = self.get_register(self.arch.SP, self.idx - 1) + + # attempt to read a pointer off the stack (possibly a ret address) + try: + maybe_ret_address = self.read_pointer(sp, self.idx) + except ValueError: + print("TODO: stack read failed") + maybe_ret_address = None + + # + # if the address off the stack matches the current address, + # we can assume that we just returned from somewhere. + # + # 99% of the time, this will have been from the call insn at + # prev_address, so let's just assume that is the case and + # 'reverse step over' onto that. + # + # NOTE: technically, we can put in more checks and stuff to + # try and ensure this is 'correct' but, step over and reverse + # step over are kind of an imperfect science as is... + # + + if maybe_ret_address != address: + self.seek(self.idx - 1) + return + + trace_prev_address = self.analysis.rebase_pointer(bin_prev_address) + + prev_idx = self.find_prev_execution(trace_prev_address, self.idx) + if prev_idx == -1: + self.seek(self.idx - 1) + return + + self.seek(prev_idx) + + #------------------------------------------------------------------------- + # Timestamp API + #------------------------------------------------------------------------- + + # + # in this section, you will find references to 'resolution'. this is a + # knob that the trace reader uses to fetch 'approximate' results from + # the underlying trace. + # + # for example, a resolution of 1 is the *most* granular request, where + # one can ask the reader to inspect each step of the trace to see if it + # matches a query (eg, 'when was this instruction address executed') + # + # in contrast, a resolution of 10_000 means that any single hit within + # a resolution 'window' is adequate, and the reader should skip to the + # next window to continue fufilling the query. + # + # given a 10 million instruction trace, and a 30px by 1000px image + # buffer to viualize said trace... there is very little reason to fetch + # 100_000 unique timestamps that all fall within one vertical pixel of + # the rendered visualization. + # + # instead, we can search the trace in arbitrary resolution 'windows' of + # roughly 1px (pixel resolution can be calculated based on the length of + # the trace execution vs the length of the viz in pixels) and fetch results + # that will suffice for visual summarization of trace execution + # + + def get_executions(self, address, resolution=1): + """ + Return a list of timestamps (idx) that executed the given address. + """ + return self.get_executions_between(address, 0, self.trace.length, resolution) + + def get_executions_between(self, address, start_idx, end_idx, resolution=1): + """ + Return a list of timestamps (idx) that executed the given address, in the given slice. + """ + assert 0 <= start_idx <= end_idx, f"0 <= {start_idx:,} <= {end_idx:,}" + assert resolution > 0 + + resolution = max(1, resolution) + #logger.debug(f"Fetching executions from {start_idx:,} --> {end_idx:,} (res {resolution:0.2f}, normalized {resolution:0.2f}) for address 0x{address:08X}") + + try: + mapped_address = self.trace.get_mapped_ip(address) + except ValueError: + return [] + + output = [] + idx = max(0, start_idx) + end_idx = min(end_idx, self.trace.length) + + while idx < end_idx: + + # fetch a segment to search forward through + seg = self.trace.get_segment(idx) + seg_base = seg.base_idx + + # clamp the segment end if it extends past our segment + seg_end = min(seg_base + seg.length, end_idx) + #logger.debug(f"Searching seg #{seg.id}, {seg_base:,} --> {seg_end:,}") + + # snip the segment to start from the given global idx + relative_idx = idx - seg_base + seg_ips = seg.ips[relative_idx:] + + while idx < seg_end: + + try: + idx_offset = seg_ips.index(mapped_address) + except ValueError: + idx = seg_end + 1 + break + + # we got a hit within the resolution window, save it + current_idx = idx + idx_offset + output.append(current_idx) + + # now skip to the next resolution window + current_resolution_index = current_idx / resolution + next_resolution_index = current_resolution_index + 1 + next_resolution_target = next_resolution_index * resolution + idx = round(next_resolution_target) + + #print(f"GOT HIT @ {current_idx:,}, skipping to {idx:,} (y = {current_idx/resolution})") + #print(f" - Current resolution index {current_resolution_index}") + #print(f" - Next resolution index {next_resolution_index}") + #print(f" - Next resolution target {next_resolution_target:,}") + + seg_ips = seg.ips[idx-seg_base:] + + #logger.debug(f"Returning hits {output}") + return output + + def get_memory_accesses(self, address, resolution=1): + """ + Return a tuple of lists (read, write) containing timestamps that access a given memory address. + """ + return self.get_memory_accesses_between(address, 0, self.trace.length, resolution) + + + def get_memory_reads_between(self, address, start_idx, end_idx, resolution=1): + """ + Return a list of timestamps that read from a given memory address in the given slice. + """ + reads, _ = self.get_memory_accesses_between(address, start_idx, end_idx, resolution, BreakpointType.READ) + return reads + + def get_memory_writes_between(self, address, start_idx, end_idx, resolution=1): + """ + Return a list of timestamps that write to a given memory address in the given slice. + """ + _, writes = self.get_memory_accesses_between(address, start_idx, end_idx, resolution, BreakpointType.WRITE) + return writes + + def get_memory_accesses_between(self, address, start_idx, end_idx, resolution=1, access_type=BreakpointType.ACCESS): + """ + Return a tuple of lists (read, write) containing timestamps that access a given memory address in the given slice. + """ + assert resolution > 0 + resolution = max(1, resolution) + + #logger.debug(f"MEMORY ACCESSES @ 0x{address:08X} // {start_idx:,} --> {end_idx:,} (rez {resolution:0.2f})") + + mapped_address = self.trace.get_mapped_address(address) + if mapped_address == -1: + return ([], []) + + reads, writes = [], [] + access_mask = self.trace.get_aligned_address_mask(address, 1) + + # clamp the search incase the given params are a bit wonky + idx = max(0, start_idx) + end_idx = min(end_idx, self.trace.length) + assert idx < end_idx + + next_resolution = [idx, idx] + + # search through the trace + while idx < end_idx: + + # fetch a segment to search forward through + seg = self.trace.get_segment(idx) + seg_base = seg.base_idx + + # clamp the segment end if it extends past our segment + seg_end = min(seg_base + seg.length, end_idx) + #logger.debug(f"seg #{seg.id}, {seg.base_idx:,} --> {seg.base_idx+seg.length:,} -- IDX PTR {idx:,}") + + mem_sets = [] + + if access_type & BreakpointType.READ: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks, reads)) + if access_type & BreakpointType.WRITE: + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks, writes)) + + for i, mem_type in enumerate(mem_sets): + idxs, addrs, masks, output = mem_type + + cumulative_index = 0 + current_target = next_resolution[i] + + while current_target < seg_end: + + try: + index = addrs.index(mapped_address) + except ValueError: + break + + cumulative_index += index + current_idx = seg_base + idxs[index] + + # + # there was a hit to the mapped address, which is aligned + # to the arch pointer size... check if the requested addr + # matches the access mask for this mem access entry + # + + if not (masks[cumulative_index] & access_mask): + addrs = addrs[index+1:] + idxs = idxs[index+1:] + cumulative_index += 1 + continue + + #print(f"FOUND ACCESS TO {self.trace.mem_addrs[mapped_address]:08X} (mask {masks[cumulative_index]:02X}), IDX {current_idx:,}") + + # we got a hit within the resolution window, save it + output.append(current_idx) + + # now skip to the next resolution window + current_resolution_index = current_idx / resolution + next_resolution_index = current_resolution_index + 1 + next_resolution_target = next_resolution_index * resolution + current_target = round(next_resolution_target) + #print(f"NEXT TARGET: {current_target:,}") + + # now skip to the next resolution window + skip_index = bisect.bisect_left(idxs, current_target - seg_base) + if skip_index == len(idxs): + break + + addrs = addrs[skip_index:] + idxs = idxs[skip_index:] + + cumulative_index += (skip_index - index) + + next_resolution[i] = current_target + + idx = seg_end + 1 + + return (reads, writes) + + def get_memory_region_reads(self, address, length, resolution=1): + """ + Return a list of timestamps that read from the given memory region. + """ + reads, _ = self.get_memory_region_accesses_between(address, length, 0, self.trace.length, resolution, BreakpointType.READ) + return reads + + def get_memory_region_reads_between(self, address, length, start_idx, end_idx, resolution=1): + """ + Return a list of timestamps that read from the given memory region in the given time slice. + """ + reads, _ = self.get_memory_region_accesses_between(address, length, start_idx, end_idx, resolution, BreakpointType.READ) + return reads + + def get_memory_region_writes(self, address, length, resolution=1): + """ + Return a list of timestamps that write to the given memory region. + """ + _, writes = self.get_memory_region_accesses_between(address, length, 0, self.trace.length, resolution, BreakpointType.WRITE) + return writes + + def get_memory_region_writes_between(self, address, length, start_idx, end_idx, resolution=1): + """ + Return a list of timestamps that write to the given memory region in the given time slice. + """ + _, writes = self.get_memory_region_accesses_between(address, length, start_idx, end_idx, resolution, BreakpointType.WRITE) + return writes + + def get_memory_region_accesses(self, address, length, resolution=1): + """ + Return a tuple of (read, write) containing timestamps that access the given memory region. + """ + return self.get_memory_region_accesses_between(address, length, 0, self.trace.length, resolution) + + def get_memory_region_accesses_between(self, address, length, start_idx, end_idx, resolution=1, access_type=BreakpointType.ACCESS): + """ + Return a tuple of (read, write) containing timestamps that access the given memory region in the given time slice. + """ + assert resolution > 0 + resolution = max(1, resolution) + + #logger.debug(f"REGION ACCESS BETWEEN @ 0x{address:08X} + {length} // {start_idx:,} --> {end_idx:,} (rez {resolution:0.2f})") + + reads, writes = [], [] + targets = self._region_to_targets(address, length) + + # clamp the search incase the given params are a bit wonky + idx = max(0, start_idx) + end_idx = min(end_idx, self.trace.length) + assert idx < end_idx + + starting_resolution_index = int(idx / resolution) + next_resolution = [starting_resolution_index, starting_resolution_index] + + while idx < end_idx: + + # fetch a segment to search forward through + seg = self.trace.get_segment(idx) + seg_base = seg.base_idx + + # clamp the segment end if it extends past our segment + seg_end = min(seg_base + seg.length, end_idx) + + #print("-"*50) + #print(f"seg #{seg.id}, {seg.base_idx:,} --> {seg.base_idx+seg.length:,} -- IDX PTR {idx:,}") + + mem_sets = [] + + if access_type & BreakpointType.READ: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks, reads)) + if access_type & BreakpointType.WRITE: + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks, writes)) + + for i, mem_type in enumerate(mem_sets): + idxs, addrs, masks, output = mem_type + hits, first_hit = {}, len(addrs) + resolution_index = next_resolution[i] + + # + # check each 'aligned address' (actually an id #) within the given region to see + # if it appears anywhere in the current segment's memory set + # + + for address_id, address_mask in targets: + + # + # if there is a memory access to the region, we will + # break here and begin processing it + # + + try: + index = addrs.index(address_id) + first_hit = min(index, first_hit) + + # + # no hits for any bytes within this aligned address, + # try the next aligned address within the region + # + + except ValueError: + continue + + hits[address_id] = address_mask + + # + # if we hit this, it means no memory accesses of this + # type (eg, reads) occured to the region of memory in + # this segment. + # + # there's nothing else to process for this memory set, + # so just break and move onto the next set (eg, writes) + # + + if not hits: + continue + + for index in range(first_hit, len(addrs)): + address_id = addrs[index] + target_mask = hits.get(address_id, None) + + if not target_mask: + continue + + #print("CLOSE! DOES MASK MATCH?") + #print(f" TARGET: 0x{self.trace.mem_addrs[address_id]:08X} MASK: {target_mask:02X}") + #print(f" CURRENT: 0x{self.trace.mem_addrs[address_id]:08X} MASK: {masks[index]:02X}") + #print(f" RESULT: {target_mask & masks[index]:02X}") + + # + # got the first hit for this set.. great! save it and + # break to search the next memory set + # + + if target_mask & masks[index]: + hit_idx = seg_base + idxs[index] + hit_resolution_index = int(hit_idx / resolution) + if hit_resolution_index < resolution_index: + continue + output.append(hit_idx) + resolution_index += 1 + + next_resolution[i] = resolution_index + + idx = seg_end + 1 + + return (reads, writes) + + def get_prev_ips(self, n, step_over=False): + """ + Return the previous n executed instruction addresses. + + If step_over=True, and a disassembler context is available to the + trace reader, it will attempt to step over calls while stepping. + """ + + # single step, return (reverse) canonical trace sequence + if not step_over: + start = max(-1, self.idx - 1) + end = max(-1, start - n) + return [self.get_ip(idx) for idx in range(start, end, -1)] + + output = [] + dctx, idx = self.dctx, self.idx + trace_address = self.get_ip(idx) + bin_address = self.analysis.rebase_pointer(trace_address) + + # (reverse) step over any call instructions + while len(output) < n and idx > 0: + + bin_prev_address = dctx.get_prev_insn(bin_address) + did_step_over = False + + # call instruction + if bin_prev_address != -1 and dctx.is_call_insn(bin_prev_address): + + # get the previous stack pointer address + sp = self.get_register(self.arch.SP, idx - 1) + + # attempt to read a pointer off the stack (the old ret address) + try: + maybe_ret_address = self.read_pointer(sp, idx) + except ValueError: + print("TODO: stack read failed") + maybe_ret_address = None + + # + # if the address off the stack matches the current address, + # we can assume that we just returned from somewhere. + # + # 99% of the time, this will have been from the call insn at + # bin_prev_address, so let's just assume that is the case and + # 'reverse step over' onto that. + # + # NOTE: technically, we can put in more checks and stuff to + # try and ensure this is 'correct' but, step over and reverse + # step over are kind of an imperfect science as is... + # + + if maybe_ret_address == trace_address: + trace_prev_address = self.analysis.rebase_pointer(bin_prev_address) + prev_idx = self.find_prev_execution(trace_prev_address, idx) + did_step_over = bool(prev_idx != -1) + + # + # if it doesn't look like we just returned from a call, we + # will just fall back to a linear, step-over backwards. + # + # this code is intended to cover the case where a conditional + # happens to jump onto an instruction immediately after a call, + # which causes the above 'stack inspection' to fail + # + + if not did_step_over: + trace_prev_address = self.analysis.rebase_pointer(bin_prev_address) + prev_idx = self.find_prev_execution(trace_prev_address, idx) + + # + # uh, wow okay we're pretty lost and have no idea if there is + # actually something that can be reverse step-over'd. just revert + # to performing a simple single-step backwards + # + + if prev_idx == -1: + prev_idx = idx - 1 + + trace_prev_address = self.get_ip(prev_idx) + + # no address was returned, so the end of trace was reached + if trace_prev_address == -1: + break + + # save the results and continue looping + output.append(trace_prev_address) + trace_address = trace_prev_address + bin_address = self.analysis.rebase_pointer(trace_address) + idx = prev_idx + + # return the list of addresses to be 'executed' next + return output + + def get_next_ips(self, n, step_over=False): + """ + Return the next N executed instruction addresses. + + If step_over=True, and a disassembler context is available to the + trace reader, it will attempt to step over calls while stepping. + """ + + # single step, return canonical trace sequence + if not step_over: + start = min(self.idx + 1, self.trace.length) + end = min(start + n, self.trace.length) + return [self.get_ip(idx) for idx in range(start, end)] + + output = [] + dctx, idx = self.dctx, self.idx + trace_address = self.get_ip(idx) + bin_address = self.analysis.rebase_pointer(trace_address) + + # step over any call instructions + while len(output) < n and idx < (self.trace.length - 1): + + # + # get the address for the instruction address after the + # current (call) instruction + # + + bin_next_address = dctx.get_next_insn(bin_address) + + # + # find the next time the instruction after this instruction is + # executed in the trace + # + + if bin_next_address != -1: + trace_next_address = self.analysis.rebase_pointer(bin_next_address) + next_idx = self.find_next_execution(trace_next_address, idx) + else: + next_idx = -1 + + # + # the instruction after the call does not appear in the trace, + # so just fall-back to 'step into' behavior + # + + if next_idx == -1: + next_idx = idx + 1 + + # + # get the next address to be executed by the trace, according to + # our stepping behavior + # + + trace_next_address = self.get_ip(next_idx) + + # no address was returned, so the end of trace was reached + if trace_next_address == -1: + break + + # save the results and continue looping + output.append(trace_next_address) + bin_address = self.analysis.rebase_pointer(trace_next_address) + idx = next_idx + + # return the list of addresses to be 'executed' next + return output + + def find_next_execution(self, address, idx=None): + """ + Return the next timestamp to execute the given address. + """ + if idx is None: + idx = self.idx + 1 + + try: + mapped_ip = self.trace.get_mapped_ip(address) + except ValueError: + return -1 + + while idx < self.trace.length: + seg = self.trace.get_segment(idx) + + # slice out and reverse the ips to search through + relative_idx = idx - seg.base_idx + ips = seg.ips[relative_idx:] + + # query for the next instance of our target ip + try: + next_idx = ips.index(mapped_ip) + return idx + next_idx + + # no luck, move backwards to the next segment + except ValueError: + idx = seg.base_idx + seg.length + + # fail, reached start of trace + return -1 + + def find_prev_execution(self, address, idx=None): + """ + Return the previous timestamp to execute the given address. + """ + if idx is None: + idx = self.idx - 1 + + try: + mapped_ip = self.trace.get_mapped_ip(address) + except ValueError: + return -1 + + while idx > -1: + seg = self.trace.get_segment(idx) + + # slice out and reverse the ips to search through + relative_idx = idx - seg.base_idx + ips = seg.ips[:relative_idx][::-1] + + # query for the next instance of our target ip + try: + prev_idx = ips.index(mapped_ip) + return idx - prev_idx - 1 + + # no luck, move backwards to the next segment + except ValueError: + idx = seg.base_idx - 1 + + # fail, reached start of trace + return -1 + + def find_next_read(self, address, idx=None): + """ + Return the next timestamp to read the given memory address. + """ + return self._find_next_mem_op(address, BreakpointType.READ, idx) + + def find_prev_read(self, address, idx=None): + """ + Return the previous timestamp to read the given memory address. + """ + return self._find_prev_mem_op(address, BreakpointType.READ, idx) + + def find_next_write(self, address, idx=None): + """ + Return the next timestamp to write to the given memory address. + """ + return self._find_next_mem_op(address, BreakpointType.WRITE, idx) + + def find_prev_write(self, address, idx=None): + """ + Return the previous timestamp to write to the given memory address. + """ + return self._find_prev_mem_op(address, BreakpointType.WRITE, idx) + + def find_next_access(self, address, idx=None): + """ + Return the next timestamp to access the given memory address. + """ + return self._find_next_mem_op(address, BreakpointType.ACCESS, idx) + + def find_prev_access(self, address, idx=None): + """ + Return the previous timestamp to access the given memory address. + """ + return self._find_prev_mem_op(address, BreakpointType.ACCESS, idx) + + def _find_next_mem_op(self, address, bp_type, idx=None): + """ + Return the next timestamp to read the given memory address. + """ + if idx is None: + idx = self.idx + 1 + + mapped_address = self.trace.get_mapped_address(address) + if mapped_address == -1: + return -1 + + access_mask = self.trace.get_aligned_address_mask(address, 1) + starting_segment = self.trace.get_segment(idx) + + accesses, mem_sets = [], [] + + for seg_id in range(starting_segment.id, len(self.trace.segments)): + seg = self.trace.segments[seg_id] + seg_base = seg.base_idx + + mem_sets.clear() + + if bp_type == BreakpointType.READ: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks)) + + if bp_type == BreakpointType.WRITE: + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks)) + + if bp_type == BreakpointType.ACCESS: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks)) + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks)) + + # loop through the read / write memory sets for this segment + for idxs, addrs, masks in mem_sets: + search_addrs = addrs + + normal_index = 0 + while search_addrs: + + try: + index = search_addrs.index(mapped_address) + normal_index += index + except ValueError: + break + + if masks[normal_index] & access_mask: + + assert addrs[normal_index] == mapped_address + assert masks[normal_index] & access_mask + + # ensure that the memory access occurs on or after the starting idx + hit_idx = seg_base + idxs[normal_index] + if idx <= hit_idx: + accesses.append(seg_base + idxs[normal_index]) + break + + # the hit was no good.. 'step' past it and keep searching + search_addrs = search_addrs[index+1:] + normal_index += 1 + + # + # if there has been a read or a write, select the one that is + # 'closest' to our current idx. there should only be, at most, + # two elements in this list... + # + + if accesses: + return min(accesses, key=lambda x:abs(x-idx)) + + # fail, reached end of trace + return -1 + + def _find_prev_mem_op(self, address, bp_type, idx=None): + """ + Return the previous timestamp to access the given memory address. + """ + if idx is None: + idx = self.idx - 1 + + mapped_address = self.trace.get_mapped_address(address) + if mapped_address == -1: + return -1 + + access_mask = self.trace.get_aligned_address_mask(address, 1) + starting_segment = self.trace.get_segment(idx) + + accesses, mem_sets = [], [] + + for seg_id in range(starting_segment.id, -1, -1): + seg = self.trace.segments[seg_id] + seg_base = seg.base_idx + + mem_sets.clear() + + if bp_type == BreakpointType.READ: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks)) + + if bp_type == BreakpointType.WRITE: + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks)) + + if bp_type == BreakpointType.ACCESS: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks)) + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks)) + + # loop through the read / write memory sets for this segment + for idxs, addrs, masks in mem_sets: + search_addrs = addrs[::-1] + + normal_index = len(search_addrs) - 1 + while search_addrs: + + try: + reverse_index = search_addrs.index(mapped_address) + normal_index -= reverse_index + except ValueError: + break + + if masks[normal_index] & access_mask: + + assert addrs[normal_index] == mapped_address + assert masks[normal_index] & access_mask + + # ensure that the memory access occurs on or before the starting idx + hit_idx = seg_base + idxs[normal_index] + if hit_idx <= idx: + accesses.append(seg_base + idxs[normal_index]) + break + + # the hit was no good.. 'step' past it and keep searching + search_addrs = search_addrs[reverse_index+1:] + normal_index -= 1 + + if accesses: + return min(accesses, key=lambda x:abs(x-idx)) + + # fail, reached start of trace + return -1 + + def find_next_region_read(self, address, length, idx=None): + """ + Return the next timestamp to read from given memory region. + """ + return self._find_next_region_access(address, length, idx, BreakpointType.READ) + + def find_next_region_write(self, address, length, idx=None): + """ + Return the next timestamp to write to the given memory region. + """ + return self._find_next_region_access(address, length, idx, BreakpointType.WRITE) + + def find_next_region_access(self, address, length, idx=None): + """ + Return the next timestamp to access (r/w) the given memory region. + """ + return self._find_next_region_access(address, length, idx, BreakpointType.ACCESS) + + def _find_next_region_access(self, address, length, idx=None, access_type=BreakpointType.ACCESS): + """ + Return the next timestamp to access the given memory region. + """ + if idx is None: + idx = self.idx + 1 + + #logger.debug(f"FIND NEXT REGION ACCESS FOR 0x{address:08X} -> 0x{address+length:08X} STARTING AT IDX {idx:,}") + + accesses, mem_sets = [], [] + targets = self._region_to_targets(address, length) + starting_segment = self.trace.get_segment(idx) + + for seg_id in range(starting_segment.id, len(self.trace.segments)): + + # fetch a segment to search forward through + seg = self.trace.segments[seg_id] + seg_base = seg.base_idx + + mem_sets = [] + + if access_type & BreakpointType.READ: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks)) + if access_type & BreakpointType.WRITE: + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks)) + + # loop through the read / write memory sets for this segment + for idxs, addrs, masks in mem_sets: + hits, first_hit = {}, len(addrs) + + # + # check each 'aligned address' (actually an id #) within + # the given region to see if it appears anywhere in the + # current segment's memory set + # + + for address_id, address_mask in targets: + + # + # if there is a memory access to the region, we will + # break here and begin processing it + # + + try: + index = addrs.index(address_id) + first_hit = min(index, first_hit) + #print(f"HIT ON 0x{self.trace.mem_addrs[address_id]:08X} @ IDX {seg_base+idxs[index]}") + + # + # no hits for any bytes within this aligned address, + # try the next aligned address within the region + # + + except ValueError: + continue + + hits[address_id] = address_mask + + # + # if we hit this, it means no memory accesses of this + # type (eg, reads) occured to the region of memory in + # this segment. + # + # there's nothing else to process for this memory set, + # so just break and move onto the next set (eg, writes) + # + + if not hits: + continue + + for index in range(first_hit, len(addrs)): + address_id = addrs[index] + target_mask = hits.get(address_id, None) + + if not target_mask: + continue + + #print("CLOSE! DOES MASK MATCH?") + #print(f" TARGET: 0x{self.trace.mem_addrs[address_id]:08X} MASK: {target_mask:02X}") + #print(f" CURRENT: 0x{self.trace.mem_addrs[address_id]:08X} MASK: {masks[index]:02X}") + #print(f" RESULT: {target_mask & masks[index]:02X}") + + # + # got the first hit for this set.. great! save it and + # break to search the next memory set + # + + if target_mask & masks[index]: + hit_idx = seg_base + idxs[index] + if hit_idx < idx: + continue + accesses.append(hit_idx) + #print(f"FOUND HIT AT IDX {hit_idx}") + break + + # + # if there has been a read or a write, select the one that is + # 'closest' to our current idx. there should only be, at most, + # two elements in this list... + # + + if accesses: + #print("ALL ACCESSES", accesses) + return min(accesses, key=lambda x:abs(x-idx)) + + # fail, reached end of trace + return -1 + + def find_prev_region_read(self, address, length, idx=None): + """ + Return the previous timestamp to read from the given memory region. + """ + return self.find_prev_region_access(address, length, idx, BreakpointType.READ) + + def find_prev_region_write(self, address, length, idx=None): + """ + Return the previous timestamp to write to the given memory region. + """ + return self.find_prev_region_access(address, length, idx, BreakpointType.WRITE) + + def find_prev_region_access(self, address, length, idx=None, access_type=BreakpointType.ACCESS): + """ + Return the previous timestamp to access the given memory region. + """ + if idx is None: + idx = self.idx - 1 + + #logger.debug(f"FIND PREV REGION ACCESS FOR 0x{address:08X} -> 0x{address+length:08X} STARTING AT IDX {idx:,}") + + accesses, mem_sets = [], [] + targets = self._region_to_targets(address, length) + starting_segment = self.trace.get_segment(idx) + + for seg_id in range(starting_segment.id, -1, -1): + + # fetch a segment to search backwards through + seg = self.trace.segments[seg_id] + seg_base = seg.base_idx + + mem_sets = [] + + if access_type & BreakpointType.READ: + mem_sets.append((seg.read_idxs, seg.read_addrs, seg.read_masks)) + if access_type & BreakpointType.WRITE: + mem_sets.append((seg.write_idxs, seg.write_addrs, seg.write_masks)) + + # loop through the read / write memory sets for this segment + for idxs, addrs, masks in mem_sets: + reverse_addrs = addrs[::-1] + hits, first_hit = {}, len(reverse_addrs) + + # + # check each 'aligned address' (actually an id #) within + # the given region to see if it appears anywhere in the + # current segment's memory set + # + + for address_id, address_mask in targets: + + # + # if there is a memory access to the region, we will + # break here and begin processing it + # + + try: + index = reverse_addrs.index(address_id) + first_hit = min(index, first_hit) + #print(f"HIT ON 0x{self.trace.mem_addrs[address_id]:08X} @ IDX {seg_base+idxs[index]}") + + # + # no hits for any bytes within this aligned address, + # try the next aligned address within the region + # + + except ValueError: + continue + + # + # ignore hits that are less than the starting timestamp + # because we are searching FORWARD, deeper into time + # + + #if seg_base + idxs[index] <= idx: + # print(f"TOSSING {seg_base+idxs[index]:,}, TOO CLOSE!") + # continue + + hits[address_id] = address_mask + + # + # if we hit this, it means no memory accesses of this + # type (eg, reads) occured to the region of memory in + # this segment. + # + # there's nothing else to process for this memory set, + # so just break and move onto the next set (eg, writes) + # + + if not hits: + continue + + num_addrs = len(reverse_addrs) + for reverse_index in range(first_hit, num_addrs): + address_id = reverse_addrs[reverse_index] + target_mask = hits.get(address_id, None) + + if not target_mask: + continue + + #print("CLOSE! DOES MASK MATCH?") + #print(f" TARGET: 0x{self.trace.mem_addrs[address_id]:08X} MASK: {target_mask:02X}") + #print(f" CURRENT: 0x{self.trace.mem_addrs[address_id]:08X} MASK: {masks[index]:02X}") + #print(f" RESULT: {target_mask & masks[index]:02X}") + + normal_index = num_addrs - reverse_index - 1 + + # + # got the first hit for this set.. great! save it and + # break to search the next memory set + # + + if target_mask & masks[normal_index]: + hit_idx = seg_base + idxs[normal_index] + if hit_idx > idx: + continue + accesses.append(hit_idx) + #print(f"FOUND HIT AT IDX {hit_idx}") + break + + # + # if there has been a read or a write, select the one that is + # 'closest' to our current idx. there should only be, at most, + # two elements in this list... + # + + if accesses: + return min(accesses, key=lambda x:abs(x-idx)) + + # fail, reached end of trace + return -1 + + def find_next_register_change(self, reg_name, idx=None): + """ + Return the next timestamp to change the given register. + """ + if idx is None: + idx = self.idx + 1 + + # if the idx is invalid, then there is nothing to do + if not(0 < idx < self.trace.length): + return -1 + + starting_segment = self.trace.get_segment(idx) + target_mask_ids = self.trace.get_reg_mask_ids_containing(reg_name) + + # search forward through the remaining segments + for seg_id in range(starting_segment.id , len(self.trace.segments)): + seg = self.trace.segments[seg_id] + seg_base = seg.base_idx + + # + # we only need to search *part* of the current segment, start + # from the given/starting idx position + # + + if seg == starting_segment: + relative_idx = idx - starting_segment.base_idx + + # for the remaining segments, we need to search them from the start + else: + relative_idx = 0 + + # search forward through the starting segment + while relative_idx < seg.length: + if seg.reg_masks[relative_idx] in target_mask_ids: + return seg.base_idx + relative_idx + relative_idx += 1 + + # fail, reached end of trace + return -1 + + def find_prev_register_change(self, reg_name, idx=None): + """ + Return the prev timestamp to change the given register. + """ + + # + # search backwards from the current trace position if a starting + # position is not specified + # + + if idx is None: + idx = self.idx - 1 + + # if the idx is invalid, then there is nothing to do + if not(0 < idx < self.trace.length): + return -1 + + starting_segment = self.trace.get_segment(idx) + target_mask_ids = self.trace.get_reg_mask_ids_containing(reg_name) + + # search backwards through the remaining segments + for seg_id in range(starting_segment.id, -1, -1): + seg = self.trace.segments[seg_id] + + # + # we only need to search *part* of the current segment, start + # from the given/starting idx position + # + + if seg == starting_segment: + relative_idx = idx - starting_segment.base_idx + + # + # for the remaining segments, we need to search them + # back to front, as we are iterating backwards in time + # + + else: + relative_idx = seg.length - 1 + + # search forward through the starting segment + while relative_idx > -1: + if seg.reg_masks[relative_idx] in target_mask_ids: + return seg.base_idx + relative_idx + relative_idx -= 1 + + # fail, reached end of trace + return -1 + + def _region_to_targets(self, address, length): + """ + Convert an (address, len) region definition into a list of [(addr_id, access_mask), ...]. + """ + ADDRESS_ALIGMENT = 8 # TODO: this is gross! + output = [] + + # + # convert the given contiguous region of memory into an array of aligned + # addresses and memory masks to mirror the 'compressed' trace format + # + + aligned_address = self.trace.get_aligned_address(address) + aligned_mask = self.trace.get_aligned_address_mask(address) + + mapped_address = self.trace.get_mapped_address(address) + if mapped_address != -1: + output.append((mapped_address, aligned_mask)) + #print(f"aligned: 0x{aligned_address} - mask {aligned_mask}") + + # the bytes consumed so far + length -= (ADDRESS_ALIGMENT - (address - aligned_address)) + aligned_address += ADDRESS_ALIGMENT + + # process the remaining.. aligned.. addresses + while length > 0: + + mapped_address = self.trace.get_mapped_address(aligned_address) + + # + # the current chunk of the region is not seen in the trace, skip + # to the next chunk + # + + if mapped_address == -1: + length -= ADDRESS_ALIGMENT + aligned_address += ADDRESS_ALIGMENT + continue + + mask_length = ADDRESS_ALIGMENT if length > ADDRESS_ALIGMENT else length + access_mask = self.trace.get_aligned_address_mask(aligned_address, mask_length) + #print(f"aligned: 0x{aligned_address:08X} - mask {access_mask:02X} - mask len {mask_length}") + + output.append((mapped_address, access_mask)) + + # continue moving through the region + length -= ADDRESS_ALIGMENT + aligned_address += ADDRESS_ALIGMENT + + #for addr, mask in output: + # print(f"TARGET {self.trace.mem_addrs[addr]:08X} MASK {mask:02X}") + + return output + + #------------------------------------------------------------------------- + # State API + #------------------------------------------------------------------------- + + def get_ip(self, idx=None): + """ + Return the instruction pointer. + + If a timestamp (idx) is provided, that will be used instead of the current timestamp. + """ + return self.trace.get_ip(idx) + + def get_register(self, reg_name, idx=None): + """ + Return a single register value. + + If a timestamp (idx) is provided, that will be used instead of the current timestamp. + """ + return self.get_registers([reg_name], idx)[reg_name] + + def get_registers(self, reg_names=None, idx=None): + """ + Return a dict of the requested registers and their values. + + If a list of registers (reg_names) is not provided, all registers will be returned. + + If a timestamp (idx) is provided, that will be used instead of the current timestamp. + """ + if idx is None: + idx = self.idx + + # no registers were specified, so we'll return *all* registers + if reg_names is None: + reg_names = self.arch.REGISTERS.copy() + + # + # if the query matches the cached (most recently acces) + # + + output_registers, target_registers = {}, reg_names.copy() + + # sanity checks + for reg_name in target_registers: + if not reg_name in self.arch.REGISTERS: + raise ValueError(f"Invalid register name: '{reg_name}'") + + # + # fast path / LRU cache of 1, pickup any registers that we've already + # queried for this timestamp and remove them from the search + # + + if idx == self._idx_cached_registers: + for name in reg_names: + if name in self._cached_registers: + output_registers[name] = self._cached_registers[name] + target_registers.remove(name) + + # + # the trace PC is stored differently, and is tacked on at the end of + # the query (if it is requested). we remove it here so we don't search + # for it in the main register query logic + # + + include_ip = False + if self.arch.IP in target_registers: + include_ip = True + target_registers.remove(self.arch.IP) + + # + # looks like everything is resolved from the cache already? so we + # can just return early... + # + + if not target_registers: + if include_ip: + output_registers[self.arch.IP] = self.trace.get_ip(idx) + return output_registers + + # + # search for the desired register values + # + + current_idx = idx + segment = self.trace.get_segment(idx) + + while segment: + + # fetch the registers of interest + found_registers = segment.get_reg_info(current_idx, target_registers) + for reg_name, info in found_registers.items(): + + # alias the reg info + reg_value, reg_idx = info + + # save the found register + output_registers[reg_name] = reg_value + + # discard the found register from the search set + target_registers.remove(reg_name) + + #print(f"Finished Seg #{segment.id}, still missing {target_registers}") + + # found all the desired registers! + if not target_registers: + break + + # TODO/XXX: uhf, this '-2' is ugly. should probably refactor. but we have to + # do -2 because get_reg_info() searches from idx + 1.. so -2 into the + # prev segment.. +1 will put us on the last idx of the segment... + + # move to the next segment if there are still registers to find... + current_idx = segment.base_idx - 2 + segment = self.trace.get_segment(current_idx) + + # fetch IP, if it was requested + if include_ip: + output_registers[self.arch.IP] = self.trace.get_ip(idx) + + # update the set of cached registers + if self._idx_cached_registers == idx: + self._cached_registers.update(output_registers) + else: + self._cached_registers = output_registers + + # the timestamp for the cached register set + self._idx_cached_registers = idx + + # return the register set for this trace index + return output_registers + + def get_memory(self, address, length, idx=None): + """ + Return the requested memeory. + + If a timestamp (idx) is provided, that will be used instead of the current timestamp. + """ + if idx is None: + idx = self.idx + + #print(f"STARTING MEM FETCH AT IDX {idx} (reader @ {self.idx})") + buffer = TraceMemory(address, length) + + # + # translate the (address, len) 'region' definition to a set of pointer + # width (eg, 8 byte) aligned addresses as used internally by the trace + # + + aligned_addresses = {(((address + i) >> 3) << 3) for i in range(length)} + + get_mapped_address = self.trace.get_mapped_address + mem_addrs = self.trace.mem_addrs + mem_masks = self.trace.mem_masks + + missing_mem = {} + for address in aligned_addresses: + + # translate the aligned addresses to their mapped addresses (a simple id) + mapped_address = get_mapped_address(address) + #print(f"SHOULD SEARCH? {address:08X} --> {mapped_address}") + + # + # if the symbolic address (a mapped id) doesn't appear in the trace + # at all, there is no need to try and fetch mem for it + # + + if mapped_address == -1: + continue + + # + # save the mask for what bytes at the aligned address should + # exist in the trace + # + + missing_mem[mapped_address] = mem_masks[mapped_address] + #print(f"MISSING 0x{address:08x} - MASK {mem_masks[mapped_address]:02X}") + + missing_mem.pop(-1, None) + + # + # + # + + starting_seg = self.trace.get_segment(idx) + seg = starting_seg + + # NOTE: writes should have priority in this list + mem_sets = \ + [ + (seg.read_idxs, seg.read_addrs, seg.read_masks), + (seg.write_idxs, seg.write_addrs, seg.write_masks), + ] + + segment_hits = {} + + # + # loop backwards through the read / write memory sets for the segment + # this get_memory() request started from (eg, the current trace position) + # + + for set_id, entries in enumerate(mem_sets): + idxs, addrs, masks = entries + + # + # slice the memory set down to just the memory accesses that occur + # before the starting idx/timestamp + # + + relative_idx = idx - starting_seg.base_idx + #print(f"ATTEMPTING TO SLICE AT RELATIVE IDX {relative_idx} (idx {idx})") + + index = bisect.bisect_right(idxs, relative_idx) + idxs = idxs[:index] + addrs = addrs[:index] + masks = masks[:index] + + # + # loop backwards through the memory access list, as we need + # to find the last-known access to a given address + # + + for hit_id in range(len(addrs) - 1, -1, -1): + current_address = addrs[hit_id] + missing_mask = missing_mem.get(current_address, 0) + #print(f"MEM ACCESS {self.trace.mem_addrs[current_address]:08X}") + #print(f" - MISSING MASK? {missing_mask:02X}") + + # the current memory access does not fall into the region + # we care about... ignore it and keep moving + if not masks[hit_id] & missing_mask: + continue + + # found a hit, save its info to evaluate after hits have + # been scraped from both sets + hits = segment_hits.setdefault(current_address, []) + hits.append((idxs[hit_id], set_id, hit_id)) + + # + # we have collected all the reads/writes to the region of interest + # for this segment... now we will go through each one until we have + # enumerated the most recent data from the lists of memory accesses + # + + for mapped_address, hits in segment_hits.items(): + #print(f"PROCESSING HIT {self.trace.mem_addrs[mapped_address]:08X}") + + # + # sort the hits to an aligned address by highest idx (most-recent) + # NOTE: mem set id will be the second sort param (writes take precedence) + # + + hits = sorted(hits, reverse=True) + #print(hits) + + # + # go through each hit for the aligned address, until its value + # has been fully resolved + # + + for relative_idx, set_id, hit_id in hits: + idxs, addrs, masks = mem_sets[set_id] + + missing_mask = missing_mem[mapped_address] + current_mask = masks[hit_id] + + #assert relative_idx < (idx - seg.base_idx), f"rel {relative_idx} vs {idx} .. {idx - seg.base_idx}" + #print(f"rel {relative_idx} vs {idx} .. {idx - seg.base_idx}") + + # if this access doesn't contain any new data of interest, ignore it + if not missing_mask & current_mask: + continue + + found_mask = missing_mask & current_mask + found_mem = seg.get_mem_data(hit_id, set_id, found_mask) + #print(f"FOUND MEM {found_mem} FOUND MASK {found_mask:02X}") + #print(f" - ADDR: 0x{found_mem.address:08X}") + #print(f" - BADDR: 0x{buffer.address:08X}, LEN {buffer.length}") + + # update the output buffer with the found memory + buffer.update(found_mem) + + # update the missing mask bits + missing_mask &= ~found_mask + + # the current address has had all of it bytes resolved + # back to a concrete values, time to bail + if not missing_mask: + missing_mem.pop(mapped_address) + break + + missing_mem[mapped_address] = missing_mask + + # + # now we will go backwards through the trace segment snapshots and + # attempt to resolve the remaining missing memory + # + + for seg_id in range(starting_seg.id-1, -1, -1): + + seg = self.trace.segments[seg_id] + mem_delta = seg.mem_delta + + to_remove = [] + + # + # loop through all the addresses that we are still missing data + # for, and check if this segment can resolve it to a concrete value + # + + for mapped_address, missing_mask in missing_mem.items(): + + # skip the current address if it doesn't get touched by this seg + if not(mapped_address in mem_delta): + continue + + # + # fetch the 'value' (1-8 bytes) that this segment sets at the + # the current aligned address + # + + mv = mem_delta[mapped_address] + + # + # if the bytes set aren't ones that we are still looking for, + # then there is nothing to fetch for this address, in this seg + # + + if not (missing_mask & mv.mask): + continue + + # + # create a mask of the missing bytes, that we can resolve with + # the memory value (mv) provided by this snapshot + # + + found_mask = missing_mask & mv.mask + + # remove the bits that this memory value will resolve + missing_mask &= ~found_mask + if not missing_mask: + to_remove.append(mapped_address) + + other_address = mem_addrs[mapped_address] + if other_address < buffer.address: + buffer_index = 0 + other_index = buffer.address - other_address + else: + buffer_index = other_address - buffer.address + other_index = 0 + + buffer_remaining = buffer.length - buffer_index + other_remaining = 8 - other_index + overlap = min(buffer_remaining, other_remaining) + + #print(f"HIT 0x{other_address:08X} IN SEG {seg_id} (started from {starting_seg.id})", ' '.join(["%02X" % x for x in mv.value])) + for i in range(overlap): + if (found_mask >> (other_index+i)) & 1: + #print(f"- GRABBING BYTE @ 0x{other_address+other_index+i:08X}, ({mv.value[other_index+i]:02X})") + buffer.data[buffer_index+i] = mv.value[other_index+i] + buffer.mask[buffer_index+i] = 0xFF + + missing_mem[mapped_address] = missing_mask + + # remove any addresses that have had their values fully resolved + for mapped_address in to_remove: + missing_mem.pop(mapped_address) + + #print("STILL MISSING", ["0x%08X" % self.trace.mem_addrs[x] for x in missing_mem]) + + # return the final / found buffer + return buffer + + def read_pointer(self, address, idx=None): + """ + Read and return a pointer at the given address from memory. + + If the value cannot be fully resolved and returned, ValueError is raised. + """ + if idx is None: + idx = self.idx + + buffer = self.get_memory(address, self.arch.POINTER_SIZE, idx) + if not len(set(buffer.mask)) == 1 and buffer.mask[0] == 0xFF: + raise ValueError("Could not fully resolve memory at address") + + pack_fmt = 'Q' if self.arch.POINTER_SIZE == 8 else 'I' + return struct.unpack(pack_fmt, buffer.data)[0] + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + def idx_changed(self, callback): + """ + Subscribe a callback for a trace navigation event. + """ + register_callback(self._idx_changed_callbacks, callback) + + def _notify_idx_changed(self): + """ + Notify listeners of an idx changed event. + """ + notify_callback(self._idx_changed_callbacks, self.idx) diff --git a/plugins_sogen-support/tenet/trace/types.py b/plugins_sogen-support/tenet/trace/types.py new file mode 100644 index 0000000..be3f747 --- /dev/null +++ b/plugins_sogen-support/tenet/trace/types.py @@ -0,0 +1,84 @@ +import array + +class TraceMemory(object): + """ + A Trace Memory Buffer. + + TODO: this is pretty trash / overraught and should be refactored. also + this can probably be moved into tenet.types? + """ + + def __init__(self, address, length): + self.address = address + self.data = array.array('B', [0]) * length + self.mask = array.array('B', [0]) * length + + def __contains__(self, address): + if self.address <= address < self.end_address: + return True + return False + + @property + def end_address(self): + return self.address + self.length + + @property + def length(self): + return len(self.data) + + def consume(self, other): + assert other.address >= self.address + + end_address = max(self.end_address, other.end_address) + new_length = end_address - self.address + + # + # if the other buffer is outside the memory region of this object, + # extend our region to include it + # + + if new_length > self.length: + new_data = array.array('B', [0]) * new_length + new_mask = array.array('B', [0]) * new_length + new_data[:self.length] = self.data[:self.length] + new_mask[:self.length] = self.mask[:self.length] + self.data = new_data + self.mask = new_mask + + # transfer data from the other memory object, into this one + base_idx = other.address - self.address + for i in range(other.length): + index = base_idx + i + if other.mask[i]: + self.data[index] = other.data[i] + self.mask[index] = 0xFF + + def update(self, other): + + if self.address < other.address: + this_start = other.address - self.address + other_start = 0 + else: + this_start = 0 + other_start = self.address - other.address + + assert this_start >= 0, f"{this_start} must be >= 0" + assert other_start >= 0, f"{other_start} must be >= 0" + + other_length_left = other.length - other_start + this_length_left = self.length - this_start + overlapped_length = min(other_length_left, this_length_left) + + #print('-'*50) + #print(f" Self Addr 0x{self.address:08X}, Len {self.length}") + #print(f"Other Addr 0x{other.address:08X}, Len {other.length}") + #print(f" Overlapping Bytes: {overlapped_length}, self start {this_start}, other start {other_start}") + + for i in range(overlapped_length): + if other.mask[other_start+i]: + self.data[this_start+i] = other.data[other_start+i] + self.mask[this_start+i] = 0xFF + + def __str__(self): + output = ["%02X" % byte if mask else "??" for byte, mask in zip(self.data, self.mask)] + return ' '.join(output) \ No newline at end of file diff --git a/plugins_sogen-support/tenet/types.py b/plugins_sogen-support/tenet/types.py new file mode 100644 index 0000000..23fefed --- /dev/null +++ b/plugins_sogen-support/tenet/types.py @@ -0,0 +1,72 @@ +import enum + +#----------------------------------------------------------------------------- +# types.py -- Plugin Types +#----------------------------------------------------------------------------- +# +# This purpose of this file is to host basic types / primitievs that +# may need to be used cross-plugin, and could be prone to causing +# cyclic dependency problems if left with their respective subsystems. +# + +#----------------------------------------------------------------------------- +# Hexdump Types +#----------------------------------------------------------------------------- + +class HexType(enum.Enum): + BYTE = 0 + SHORT = 1 + DWORD = 2 + QWORD = 3 + POINTER = 4 + MAGIC = 5 + +class AuxType(enum.Enum): + NONE = 0 + ASCII = 1 + STACK = 2 + +HEX_TYPE_WIDTH = \ +{ + HexType.BYTE: 1, + HexType.SHORT: 2, + HexType.DWORD: 4, + HexType.QWORD: 8, + HexType.POINTER: 8, # XXX: should be 4 or 8 + HexType.MAGIC: 1, +} + +class HexItem(object): + def __init__(self, value, mask, width, item_type): + self.value = value + self.mask = mask + self.width = width # width in bytes + self.type = item_type + +#----------------------------------------------------------------------------- +# Breakpoint Types +#----------------------------------------------------------------------------- + +class BreakpointType(enum.IntEnum): + NONE = 1 << 0 + READ = 1 << 1 + WRITE = 1 << 2 + EXEC = 1 << 3 + ACCESS = (READ | WRITE) + +class BreakpointEvent(enum.Enum): + ADDED = 0 + REMOVED = 1 + ENABLED = 2 + DISABLED = 3 + +class TraceBreakpoint(object): + """ + A simple class to encapsulate the properties of a breakpoint definition. + """ + def __init__(self, address, access_type=BreakpointType.NONE, length=1): + assert not(address is None) + self.type = access_type + self.address = address + self.length = length + self.enabled = True \ No newline at end of file diff --git a/plugins_sogen-support/tenet/ui/__init__.py b/plugins_sogen-support/tenet/ui/__init__.py new file mode 100644 index 0000000..afd5e8f --- /dev/null +++ b/plugins_sogen-support/tenet/ui/__init__.py @@ -0,0 +1,8 @@ +from tenet.util.qt import QT_AVAILABLE + +# import Qt based plugin UI if available +if QT_AVAILABLE: + from tenet.ui.palette import PluginPalette + from tenet.ui.hex_view import HexView + from tenet.ui.reg_view import RegisterView + from tenet.ui.breakpoint_view import BreakpointView diff --git a/plugins_sogen-support/tenet/ui/__pycache__/__init__.cpython-311.pyc b/plugins_sogen-support/tenet/ui/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..880afa8 Binary files /dev/null and b/plugins_sogen-support/tenet/ui/__pycache__/__init__.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/ui/__pycache__/breakpoint_view.cpython-311.pyc b/plugins_sogen-support/tenet/ui/__pycache__/breakpoint_view.cpython-311.pyc new file mode 100644 index 0000000..4648562 Binary files /dev/null and b/plugins_sogen-support/tenet/ui/__pycache__/breakpoint_view.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/ui/__pycache__/hex_view.cpython-311.pyc b/plugins_sogen-support/tenet/ui/__pycache__/hex_view.cpython-311.pyc new file mode 100644 index 0000000..407d157 Binary files /dev/null and b/plugins_sogen-support/tenet/ui/__pycache__/hex_view.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/ui/__pycache__/palette.cpython-311.pyc b/plugins_sogen-support/tenet/ui/__pycache__/palette.cpython-311.pyc new file mode 100644 index 0000000..8b14515 Binary files /dev/null and b/plugins_sogen-support/tenet/ui/__pycache__/palette.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/ui/__pycache__/reg_view.cpython-311.pyc b/plugins_sogen-support/tenet/ui/__pycache__/reg_view.cpython-311.pyc new file mode 100644 index 0000000..91f86c1 Binary files /dev/null and b/plugins_sogen-support/tenet/ui/__pycache__/reg_view.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/ui/__pycache__/trace_view.cpython-311.pyc b/plugins_sogen-support/tenet/ui/__pycache__/trace_view.cpython-311.pyc new file mode 100644 index 0000000..1427875 Binary files /dev/null and b/plugins_sogen-support/tenet/ui/__pycache__/trace_view.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/ui/breakpoint_view.py b/plugins_sogen-support/tenet/ui/breakpoint_view.py new file mode 100644 index 0000000..ad33a7c --- /dev/null +++ b/plugins_sogen-support/tenet/ui/breakpoint_view.py @@ -0,0 +1,45 @@ +# +# TODO: I don't think this file is even in use right now, but w/e +# we'll ship it for now... +# + +from tenet.util.qt import * + +class BreakpointDock(QtWidgets.QDockWidget): + """ + Dockable wrapper of a Breakpoint view. + """ + def __init__(self, view, parent=None): + super(BreakpointDock, self).__init__(parent) + self.setAllowedAreas(QtCore.Qt.AllDockWidgetAreas) + self.setWindowTitle("Breakpoints") + self.setWidget(view) + +class BreakpointView(QtWidgets.QWidget): + """ + The Breakpoint Widget (UI) + """ + + def __init__(self, controller, model, parent=None): + super(BreakpointView, self).__init__(parent) + self.controller = controller + self.model = model + self._init_ui() + + def _init_ui(self): + self.setMinimumHeight(100) + + self._init_table() + + layout = QtWidgets.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._table) + + def _init_table(self): + self._table = QtWidgets.QTableWidget(self) + self._table.insertColumn(0) + self._table.insertColumn(1) + self._table.insertColumn(2) + self._table.insertColumn(3) + self._table.setHorizontalHeaderLabels(["Type", "Enabled", "Address", "Delete"]) + diff --git a/plugins_sogen-support/tenet/ui/hex_view.py b/plugins_sogen-support/tenet/ui/hex_view.py new file mode 100644 index 0000000..a1ab4dc --- /dev/null +++ b/plugins_sogen-support/tenet/ui/hex_view.py @@ -0,0 +1,1012 @@ +import struct + +from tenet.types import * +from tenet.util.qt import * + +INVALID_ADDRESS = -1 + +class HexView(QtWidgets.QAbstractScrollArea): + """ + A Qt based hex / memory viewer. + + Adapted from: + - https://github.com/virinext/QHexView + + """ + + def __init__(self, controller, model, parent=None): + super(HexView, self).__init__(parent) + self.controller = controller + self.model = model + self._palette = controller.pctx.palette + + self.setFocusPolicy(QtCore.Qt.StrongFocus) + self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) + + font = QtGui.QFont("Courier", pointSize=normalize_font(9)) + font.setStyleHint(QtGui.QFont.TypeWriter) + self.setFont(font) + self.setMouseTracking(True) + + fm = QtGui.QFontMetricsF(font) + self._char_width = fm.width('9') + self._char_height = int(fm.tightBoundingRect('9').height() * 1.75) + self._char_descent = self._char_height - fm.descent()*0.75 + + self._click_timer = QtCore.QTimer(self) + self._click_timer.setSingleShot(True) + self._click_timer.timeout.connect(self._commit_click) + + self._double_click_timer = QtCore.QTimer(self) + self._double_click_timer.setSingleShot(True) + + self.hovered_address = INVALID_ADDRESS + + self._selection_start = INVALID_ADDRESS + self._selection_end = INVALID_ADDRESS + + self._pending_selection_origin = INVALID_ADDRESS + self._pending_selection_start = INVALID_ADDRESS + self._pending_selection_end = INVALID_ADDRESS + + self._ignore_navigation = False + + self._init_ctx_menu() + + def _init_ctx_menu(self): + """ + Initialize the right click context menu actions. + """ + + # create actions to show in the context menu + self._action_copy = QtWidgets.QAction("Copy", None) + self._action_clear = QtWidgets.QAction("Clear mem breakpoints", None) + self._action_follow_in_dump = QtWidgets.QAction("Follow in dump", None) + + bp_types = \ + [ + ("Read", BreakpointType.READ), + ("Write", BreakpointType.WRITE), + ("Access", BreakpointType.ACCESS) + ] + + # + # break on action group + # + + self._action_break = {} + + for name, bp_type in bp_types: + action = QtWidgets.QAction(name, None) + action.setCheckable(True) + self._action_break[action] = bp_type + + self._break_menu = QtWidgets.QMenu("Break on...") + self._break_menu.addActions(self._action_break) + + # + # goto action groups + # + + self._action_first = {} + self._action_prev = {} + self._action_next = {} + self._action_final = {} + + for name, bp_type in bp_types: + self._action_prev[QtWidgets.QAction(name, None)] = bp_type + self._action_next[QtWidgets.QAction(name, None)] = bp_type + self._action_first[QtWidgets.QAction(name, None)] = bp_type + self._action_final[QtWidgets.QAction(name, None)] = bp_type + + self._goto_menus = \ + [ + (QtWidgets.QMenu("Go to first..."), self._action_first), + (QtWidgets.QMenu("Go to previous..."), self._action_prev), + (QtWidgets.QMenu("Go to next..."), self._action_next), + (QtWidgets.QMenu("Go to final..."), self._action_final), + ] + + for submenu, actions in self._goto_menus: + submenu.addActions(actions) + + # install the right click context menu + self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) + self.customContextMenuRequested.connect(self._ctx_menu_handler) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def num_lines_visible(self): + """ + Return the number of lines visible in the hex view. + """ + area_size = self.viewport().size() + first_line_idx = self.verticalScrollBar().value() + last_line_idx = (first_line_idx + area_size.height() // self._char_height) + 1 + lines_visible = last_line_idx - first_line_idx + return lines_visible + + @property + def num_bytes_visible(self): + """ + Return the number of bytes visible in the hex view. + """ + return self.model.num_bytes_per_line * self.num_lines_visible + + @property + def selection_size(self): + """ + Return the number of bytes selected in the hex view. + """ + if self._selection_end == self._selection_start == INVALID_ADDRESS: + return 0 + return self._selection_end - self._selection_start + + @property + def hovered_breakpoint(self): + """ + Return the hovered breakpoint. + """ + if self.hovered_address == INVALID_ADDRESS: + return None + + for bp in self.model.memory_breakpoints: + if bp.address <= self.hovered_address < bp.address + bp.length: + return bp + + return None + + #------------------------------------------------------------------------- + # Internal + #------------------------------------------------------------------------- + + def refresh(self): + """ + Refresh the hex view. + """ + self._refresh_painting_metrics() + self.viewport().update() + + def _refresh_painting_metrics(self): + """ + Refresh any metrics and calculations required to paint the widget. + """ + + # 2 chars per byte of data, eg '00' + self._chars_in_line = self.model.num_bytes_per_line * 2 + + # add 1 char for each space between elements (bytes, dwords, qwords...) + self._chars_in_line += (self.model.num_bytes_per_line // HEX_TYPE_WIDTH[self.model.hex_format]) + + # the x position to draw the text address (left side of view) + self._pos_addr = self._char_width // 2 + + # the width of the column, 2 nibbles (chars) per byte of a pointer + # -- +1 for padding, (eg, 1/2 char on each side) + self._width_addr = (self.model.pointer_size * 2 + 1) * self._char_width + + # the x position and width of the hex bytes region (center section of view) + self._pos_hex = self._width_addr + self._char_width + self._width_hex = self._chars_in_line * self._char_width + + # the x position and width of the auxillary region (right section of view) + self._pos_aux = self._pos_hex + self._width_hex + self._width_aux = (self.model.num_bytes_per_line * self._char_width) + self._char_width * 2 + + # enforce a minimum view width, to ensure all text stays visible + self.setMinimumWidth(int(self._pos_aux + self._width_aux)) + + def full_size(self): + """ + TODO + """ + if not self.model.data: + return QtCore.QSize(0, 0) + + width = int(self._pos_aux + (self.model.num_bytes_per_line * self._char_width)) + height = len(self.model.data) // self.model.num_bytes_per_line + if len(self.model.data) % self.model.num_bytes_per_line: + height += 1 + + height *= self._char_height + + return QtCore.QSize(width, height) + + def point_to_index(self, position): + """ + Convert a QPoint (x, y) on the hex view window to a byte index. + + TODO/XXX: ugh this whole function / selection logic needs to be + rewritten... it's actually impossible to follow. + """ + padding = self._char_width // 2 + + if position.x() < (self._pos_hex - padding): + return -1 + + cutoff = self._pos_hex + self._width_hex - padding + #print(f"Position: {position} Cutoff: {cutoff} Pos Hex: {self._pos_hex} Width Hex: {self._width_hex} Padding: {padding}") + if position.x() >= cutoff: + return -1 + + # convert 'gloabl' x in the viewport, to an x that is 'relative' to the hex area + hex_x = (position.x() - self._pos_hex) + padding + #print("- Hex x", hex_x) + + # the number of items (eg, bytes, qwords) per line + num_items = self.model.num_bytes_per_line // HEX_TYPE_WIDTH[self.model.hex_format] + #print("- Num items", num_items) + + # compute the pixel width each rendered item on the line takes up + item_width = (self._char_width * 2) * HEX_TYPE_WIDTH[self.model.hex_format] + item_width_padded = item_width + self._char_width + #print("- Item Width", item_width) + #print("- Item Width Padded", item_width_padded) + + # compute the item index on a line (the x-axis) that the point falls within + item_index = int(hex_x // item_width_padded) + #print("- Item Index", item_index) + + # compute which byte is hovered in the item + if self.model.hex_format != HexType.BYTE: + + item_base_x = item_index * item_width_padded + (self._char_width // 2) + item_byte_x = hex_x - item_base_x + item_byte_index = int(item_byte_x // (self._char_width * 2)) + + # XXX: I give up, kludge to account for math errors + if item_byte_index < 0: + item_byte_index = 0 + elif item_byte_index >= self.model.num_bytes_per_line: + item_byte_index = self.model.num_bytes_per_line - 1 + + #print("- Item Byte X", item_byte_x) + #print("- Item Byte Index", item_byte_index) + + item_byte_index = (HEX_TYPE_WIDTH[self.model.hex_format] - 1) - item_byte_index + byte_x = item_index * HEX_TYPE_WIDTH[self.model.hex_format] + item_byte_index + + else: + byte_x = item_index * HEX_TYPE_WIDTH[self.model.hex_format] + + # compute the line number (the y-axis) that the point falls within + byte_y = position.y() // self._char_height + #print("- Byte (X, Y)", byte_x, byte_y) + + # compute the final byte index from the start address in the window + byte_index = (byte_y * self.model.num_bytes_per_line) + byte_x + #print("- Byte Index", byte_index) + + return byte_index + + def point_to_address(self, position): + """ + Convert a QPoint (x, y) on the hex view window to an address. + """ + byte_index = self.point_to_index(position) + if byte_index == -1: + return INVALID_ADDRESS + + byte_address = self.model.address + byte_index + return byte_address + + def point_to_breakpoint(self, position): + """ + Convert a QPoint (x, y) on the hex view window to a breakpoint. + """ + byte_address = self.point_to_address(position) + if byte_address == INVALID_ADDRESS: + return None + + for bp in self.model.memory_breakpoints: + if bp.address <= byte_address < bp.address + bp.length: + return bp + + return None + + def reset_selection(self): + """ + Clear the stored user memory selection. + """ + self._pending_selection_origin = INVALID_ADDRESS + self._pending_selection_start = INVALID_ADDRESS + self._pending_selection_end = INVALID_ADDRESS + self._selection_start = INVALID_ADDRESS + self._selection_end = INVALID_ADDRESS + + def _update_selection(self, position): + """ + Set the user memory selection. + """ + address = self.point_to_address(position) + if address == INVALID_ADDRESS: + return + + if address >= self._pending_selection_origin: + self._pending_selection_end = address + 1 + self._pending_selection_start = self._pending_selection_origin + else: + self._pending_selection_start = address + self._pending_selection_end = self._pending_selection_origin + 1 + + def _commit_click(self): + """ + Accept a click event. + """ + self._selection_start = self._pending_selection_start + self._selection_end = self._pending_selection_end + + self._pending_selection_origin = INVALID_ADDRESS + self._pending_selection_start = INVALID_ADDRESS + self._pending_selection_end = INVALID_ADDRESS + + self.viewport().update() + + def _commit_selection(self): + """ + Accept a selection event. + """ + self._selection_start = self._pending_selection_start + self._selection_end = self._pending_selection_end + + self._pending_selection_origin = INVALID_ADDRESS + self._pending_selection_start = INVALID_ADDRESS + self._pending_selection_end = INVALID_ADDRESS + + # notify listeners of our selection change + #self._notify_selection_changed(new_start, new_end) + self.viewport().update() + + #-------------------------------------------------------------------------- + # Signals + #-------------------------------------------------------------------------- + + def _ctx_menu_handler(self, position): + """ + Handle a right click event (populate/show context menu). + """ + menu = QtWidgets.QMenu() + + ctx_breakpoint = self.point_to_breakpoint(position) + ctx_address = self.point_to_address(position) + ctx_type = BreakpointType.NONE + + # + # determine the selection that the action will execute across + # + + if self._selection_start <= ctx_address < self._selection_end: + selected_address = self._selection_start + selected_length = self.selection_size + + elif ctx_breakpoint: + selected_address = ctx_breakpoint.address + selected_length = ctx_breakpoint.length + ctx_type = ctx_breakpoint.type + + else: + selected_address = INVALID_ADDRESS + selected_length = 0 + + # + # populate the popup menu + # + + # show the 'copy text' option if the user has a region selected + if selected_length > 1 and ctx_type == BreakpointType.NONE: + menu.addAction(self._action_copy) + + # only show the 'follow in dump' if the controller supports it + if hasattr(self.controller, "follow_in_dump"): + menu.addAction(self._action_follow_in_dump) + + menu.addSeparator() + + # show the break option only if there's a selection or breakpoint + if selected_length > 0: + menu.addMenu(self._break_menu) + menu.addSeparator() + + for action, access_type in self._action_break.items(): + action.setChecked(ctx_type == access_type) + + if selected_length > 0: + + # add the goto groups + for submenu, _ in self._goto_menus: + menu.addMenu(submenu) + + # show the 'clear breakpoints' action + menu.addSeparator() + menu.addAction(self._action_clear) + + # + # show the right click context menu + # + + action = menu.exec_(self.mapToGlobal(position)) + if not action: + return + + # + # execute the action selected by the suer in the right click menu + # + + if action == self._action_copy: + self.controller.copy_selection(self._selection_start, self._selection_end) + return + + elif action == self._action_follow_in_dump: + self.controller.follow_in_dump(self._selection_start) + return + + elif action == self._action_clear: + self.controller.pctx.breakpoints.clear_memory_breakpoints() + return + + # TODO: this is some of the shadiest/laziest code i've ever written + try: + selected_type = getattr(BreakpointType, action.text().upper()) + except: + pass + + if action in self._action_first: + self.controller.reader.seek_to_first(selected_address, selected_type, selected_length) + elif action in self._action_prev: + self.controller.reader.seek_to_prev(selected_address, selected_type, selected_length) + elif action in self._action_next: + self.controller.reader.seek_to_next(selected_address, selected_type, selected_length) + elif action in self._action_final: + self.controller.reader.seek_to_final(selected_address, selected_type, selected_length) + elif action in self._action_break: + self.controller.pin_memory(selected_address, selected_type, selected_length) + self.reset_selection() + + #---------------------------------------------------------------------- + # Qt Overloads + #---------------------------------------------------------------------- + + def mouseDoubleClickEvent(self, event): + """ + Qt overload to capture mouse double-click events. + """ + self._click_timer.stop() + + # + # if the double click fell within an active selection, we should + # consume the event as the user setting a region breakpoint + # + + if self._selection_start <= self._pending_selection_start < self._selection_end: + address = self._selection_start + size = self.selection_size + else: + address = self.point_to_address(event.pos()) + size = 1 + + self.controller.pin_memory(address, length=size) + self.reset_selection() + event.accept() + + self.viewport().update() + self._double_click_timer.start(100) + + def mouseMoveEvent(self, event): + """ + Qt overload to capture mouse movement events. + """ + mouse_position = event.pos() + + # update the hovered address + self.hovered_address = self.point_to_address(mouse_position) + + # mouse moving while holding left button + if event.buttons() == QtCore.Qt.MouseButton.LeftButton: + self._update_selection(mouse_position) + + # + # if the user is actively selecting bytes and has selected more + # than one byte, we should clear any existing selection. this will + # make it so the new ongoing 'pending' selection will get drawn + # + + if (self._pending_selection_end - self._pending_selection_start) > 1: + self._selection_start = INVALID_ADDRESS + self._selection_end = INVALID_ADDRESS + + self.viewport().update() + return + + def mousePressEvent(self, event): + """ + Qt overload to capture mouse button presses. + """ + if self._double_click_timer.isActive(): + return + + if event.button() == QtCore.Qt.LeftButton: + + byte_address = self.point_to_address(event.pos()) + + if not(self._selection_start <= byte_address < self._selection_end): + self.reset_selection() + + self._pending_selection_origin = byte_address + self._pending_selection_start = byte_address + self._pending_selection_end = (byte_address + 1) if byte_address != INVALID_ADDRESS else INVALID_ADDRESS + + self.viewport().update() + + def mouseReleaseEvent(self, event): + """ + Qt overload to capture mouse button releases. + """ + if self._double_click_timer.isActive(): + return + + # handle a right click + if event.button() == QtCore.Qt.RightButton: + + # get the address of the byte that was right clicked + byte_address = self.point_to_address(event.pos()) + if byte_address == INVALID_ADDRESS: + return + + # the right clicked fell within the current selection + if self._selection_start <= byte_address < self._selection_end: + return + + # the right click fell within an existing breakpoint + bp = self.hovered_breakpoint + if bp and (bp.address <= byte_address < bp.address + bp.length): + return + + # + # if the right click did not fall within any known selection / poi + # we should consume it and set the current cursor selection to it + # + + self._pending_selection_start = byte_address + self._pending_selection_end = byte_address + 1 + self._commit_click() + return + + if self._pending_selection_origin == INVALID_ADDRESS: + return + + # if the mouse press & release was on a single byte, it's a click + if (self._pending_selection_end - self._pending_selection_start) == 1: + + # + # if the click was within a selected region, defer acting on it + # for 500ms to see if a double click event occurs + # + + if self._selection_start <= self._pending_selection_start < self._selection_end: + self._click_timer.start(200) + return + else: + self._commit_click() + + # a range was selected, so accept/commit it + else: + self._commit_selection() + + def keyPressEvent(self, e): + """ + Qt overload to capture key press events. + """ + if e.key() == QtCore.Qt.Key_G: + import ida_kernwin, ida_idaapi + address = ida_kernwin.ask_addr(self.model.address, "Jump to address in memory") + if address != None and address != ida_idaapi.BADADDR: + self.controller.navigate(address) + e.accept() + return super(HexView, self).keyPressEvent(e) + + def wheelEvent(self, event): + """ + Qt overload to capture wheel events. + """ + + # + # first, we will attempt special handling of the case where a user + # 'scrolls' up or down when hovering their cursor over a byte they + # have selected... + # + + # compute the address of the hovered byte (if there is one...) + byte_address = self.point_to_address(event.pos()) + + for bp in self.model.memory_breakpoints: + + # skip this breakpoint if the current byte does not fall within its range + if not(bp.address <= byte_address < bp.address + bp.length): + continue + + # + # XXX: bit of a hack, but it seems like the easiest way to prevent + # the stack views from 'navigating' when you're hovering / scrolling + # through memory accesses (see _idx_changed in stack.py) + # + + self._ignore_navigation = True + + # + # if a region is selected with an 'access' breakpoint on it, + # use the start address of the selected region instead for + # the region-based seeks + # + + # scrolled 'up' + if event.angleDelta().y() > 0: + self.controller.reader.seek_to_prev(bp.address, bp.type, bp.length) + + # scrolled 'down' + elif event.angleDelta().y() < 0: + self.controller.reader.seek_to_next(bp.address, bp.type, bp.length) + + # restore navigation listening + self._ignore_navigation = False + + # consume the event + event.accept() + return + + # + # normal 'scroll' on the hex window.. scroll up or down into new + # regions of memory... + # + + if event.angleDelta().y() > 0: + self.controller.navigate(self.model.address - self.model.num_bytes_per_line) + + elif event.angleDelta().y() < 0: + self.controller.navigate(self.model.address + self.model.num_bytes_per_line) + + event.accept() + + def resizeEvent(self, event): + """ + Qt overload to capture resize events for the widget. + """ + super(HexView, self).resizeEvent(event) + self._refresh_painting_metrics() + self.controller.set_data_size(self.num_bytes_visible) + + #------------------------------------------------------------------------- + # Painting + #------------------------------------------------------------------------- + + def paintEvent(self, event): + """ + Qt overload of widget painting. + """ + if not self.model.data: + return + + painter = QtGui.QPainter(self.viewport()) + + # paint background of entire scroll area + painter.fillRect(event.rect(), self._palette.hex_data_bg) + + # paint address area background + address_area_rect = QtCore.QRect(0, event.rect().top(), int(self._width_addr), self.height()) + painter.fillRect(address_area_rect, self._palette.hex_address_bg) + + # paint line between address area and hex area + painter.setPen(self._palette.hex_separator) + painter.drawLine(int(self._width_addr), event.rect().top(), int(self._width_addr), self.height()) + + # paint line between hex area and auxillary area + line_pos = self._pos_aux + painter.setPen(self._palette.hex_separator) + painter.drawLine(int(line_pos), event.rect().top(), int(line_pos), self.height()) + + for line_idx in range(0, self.num_lines_visible): + self._paint_line(painter, line_idx) + + def _paint_line(self, painter, line_idx): + """ + Paint one line of hex. + """ + self._brush_default = painter.brush() + self._brush_selected = QtGui.QBrush(self._palette.standard_selection_bg) + self._brush_navigation = QtGui.QBrush(self._palette.navigation_selection_fg) + + # the pixel position to start painting from + x, y = self._pos_hex, (line_idx + 1) * self._char_height + + # clamp the address from 0 to 0xFFFFFFFFFFFFFFFF + address = self.model.address + (line_idx * self.model.num_bytes_per_line) + if address > 0xFFFFFFFFFFFFFFFF: + address = 0xFFFFFFFFFFFFFFFF + + address_color = self._palette.hex_address_fg + if address < self.model.fade_address: + address_color = self._palette.hex_text_faded_fg + + painter.setPen(address_color) + + # draw the address text + pack_len = self.model.pointer_size + address_fmt = '%016X' if pack_len == 8 else '%08X' + address_text = address_fmt % address + painter.drawText(int(self._pos_addr), y, address_text) + + self._default_color = self._palette.hex_text_fg + if address < self.model.fade_address: + self._default_color = self._palette.hex_text_faded_fg + + painter.setPen(self._default_color) + + byte_base_idx = line_idx * self.model.num_bytes_per_line + byte_idx = byte_base_idx + stop_idx = min(len(self.model.data), byte_base_idx + self.model.num_bytes_per_line) + + # paint each element on the line, up until the end of the line, or buffer + while byte_idx < stop_idx: + byte_idx, x, y = self._paint_hex_item(painter, byte_idx, stop_idx, x, y) + + assert byte_idx == stop_idx + + # + # paint 'readable' ASCII + # + + byte_idx = byte_base_idx + x_pos_aux = self._pos_aux + self._char_width + + if self.model.aux_format == AuxType.ASCII: + + for i in range(byte_base_idx, stop_idx): + + if self.model.mask[i]: + painter.setPen(self._default_color) + else: + painter.setPen(self._palette.hex_text_faded_fg) + + ch = self.model.data[i] + if ((ch < 0x20) or (ch > 0x7e)): + ch = '.' + else: + ch = chr(ch) + + painter.drawText(int(x_pos_aux), y, ch) + x_pos_aux += self._char_width + + def _paint_hex_item(self, painter, byte_idx, stop_idx, x, y): + """ + Paint a single hex item. + """ + + # draw single bytes + if self.model.hex_format == HexType.BYTE: + return self._paint_byte(painter, byte_idx, x, y) + + # draw dwords + elif self.model.hex_format == HexType.DWORD: + return self._paint_dword(painter, byte_idx, x, y) + + # draw qwords + elif self.model.hex_format == HexType.QWORD: + return self._paint_qword(painter, byte_idx, x, y) + + # identify and draw pointers + elif self.model.hex_format == HexType.MAGIC: + return self._paint_magic(painter, byte_idx, stop_idx, x, y) + + raise NotImplementedError("Unknown HexType format! %s" % self.model.hex_format) + + #return (byte_idx, x, y) + + def _paint_byte(self, painter, byte_idx, x, y): + """ + Paint a BYTE at the current position. + """ + self._paint_text(painter, byte_idx, 1, x, y) + x += (2 + 1) * self._char_width + + return (byte_idx + 1, x, y) + + def _paint_dword(self, painter, byte_idx, x, y): + """ + Paint a DWORD at the current position. + """ + backwards_idx = byte_idx - 1 + + for i in range(backwards_idx + 4, backwards_idx, -1): + self._paint_text(painter, i, 0, x, y) + x += self._char_width * 2 + + return (byte_idx + 4, x, y) + + def _paint_qword(self, painter, byte_idx, x, y): + """ + Paint a QWORD at the current position. + """ + backwards_idx = byte_idx - 1 + + for i in range(backwards_idx + 8, backwards_idx, -1): + self._paint_text(painter, i, 0, x, y) + x += self._char_width * 2 + + return (byte_idx + 8, x, y) + + def _paint_text(self, painter, byte_idx, padding, x, y): + + if self.model.mask[byte_idx]: + fg_color = self._default_color + text = "%02X" % self.model.data[byte_idx] + else: + fg_color = self._palette.hex_text_faded_fg + text = "??" + + # + # paint text selection background color / highlight + # + + x_bg = x - (self._char_width // 2) * padding + y_bg = y - self._char_descent + + width = self._char_width * (len(text) + padding) + height = self._char_height + + bg_color = None + border_color = None + + # compute the address of the byte we're drawing + byte_address = self.model.address + byte_idx + + # initialize selection start / end vars + start_address = INVALID_ADDRESS + end_address = INVALID_ADDRESS + + # fixed / committed selection + if self._selection_start != INVALID_ADDRESS: + start_address = self._selection_start + end_address = self._selection_end + + # active / on-going selection event + elif self._pending_selection_start != INVALID_ADDRESS: + start_address = self._pending_selection_start + end_address = self._pending_selection_end + + # a byte that falls within the user selection + if start_address <= byte_address < end_address: + bg_color = self._palette.standard_selection_bg + + # set the text color for selected text + if self.model.mask[byte_idx]: + fg_color = self._palette.standard_selection_fg + else: + fg_color = self._palette.standard_selection_faded_fg + + # a byte that was written + elif byte_address in self.model.delta.mem_writes: + bg_color = self._palette.mem_write_bg + fg_color = self._palette.mem_write_fg + + # a byte that was read + elif byte_address in self.model.delta.mem_reads: + bg_color = self._palette.mem_read_bg + fg_color = self._palette.mem_read_fg + + # a breakpoint byte + for bp in self.model.memory_breakpoints: + + # skip this breakpoint if the current byte does not fall within its range + if not(bp.address <= byte_address < bp.address + bp.length): + continue + + # + # if the breakpoint is a single byte, ensure it will always have a + # border around it, regardless of if it is selected, read, or + # written. + # + # this makes it easy to tell when you have selected or are hovering + # an active 'hot' byte / breakpoint that can be scrolled over to + # seek between accesses + # + + if bp.length == 1: + border_color = self._palette.navigation_selection_bg + + # + # if the background color for this byte has already been + # specified, that means a read/write probably occured to it so + # we should prioritize those colors OVER the breakpoint coloring + # + + if bg_color: + break + + # + # if the byte wasn't read/written/selected, we are free to color + # it red, as it falls within an active breakpoint region + # + + bg_color = self._palette.navigation_selection_bg + + # if the byte value is know (versus '??'), set its text color + if self.model.mask[byte_idx]: + fg_color = self._palette.navigation_selection_fg + else: + fg_color = self._palette.navigation_selection_faded_fg + + # + # no need to keep searching through breakpoints once the byte has + # been colored! break and go paint the byte... + # + + break + + # the byte is highlighted in some fashion, paint it now + if bg_color: + + if border_color: + pen = QtGui.QPen(border_color, 2) + pen.setJoinStyle(QtCore.Qt.MiterJoin) + painter.setPen(pen) + x_bg += 1 + y_bg += 1 + width -= 2 + height -= 2 + + else: + painter.setPen(QtCore.Qt.NoPen) + + painter.setBrush(bg_color) + painter.drawRect(int(x_bg), int(y_bg), int(width), int(height)) + + painter.setPen(fg_color) + + # + # paint text + # + + painter.drawText(int(x), y, text) + + def _paint_magic(self, painter, byte_idx, stop_idx, x, y): + """ + Perform magic painting at the current position. + + This will essentially try to identify pointers while painting, and + format them as appropriate. + + TODO: this needs to be updated to be truly pointer size agnostic + """ + + # not enough bytes left to identify / paint a pointer from the data + if byte_idx + self.model.pointer_size > stop_idx: + return self._paint_byte(painter, byte_idx, x, y) + + # ensure that all the bytes for the 'pointer' to analyze are known + pack_len = self.model.pointer_size + pack_fmt = 'Q' if pack_len == 8 else 'I' + mask = struct.unpack(pack_fmt, self.model.mask[byte_idx:byte_idx+pack_len])[0] + if mask != 0xFFFFFFFFFFFFFFFF: + return self._paint_byte(painter, byte_idx, x, y) + + # read and analyze the value to determine if it is a pointer + value = struct.unpack(pack_fmt, self.model.data[byte_idx:byte_idx+pack_len])[0] + if not self.controller.pctx.is_pointer(value): + return self._paint_byte(painter, byte_idx, x, y) + + # + # it seems like a pointer, let's draw one! + # + + # compute how many characters would have normally filled this space + # if inidividual bytes were printed instead... + num_chars = 3 * self.model.pointer_size + + # draw the pointer + pointer_str = ("0x%08X " % value).rjust(num_chars) + painter.drawText(int(x), y, pointer_str) + x += num_chars * self._char_width + + return (byte_idx + self.model.pointer_size, x, y) diff --git a/plugins_sogen-support/tenet/ui/palette.py b/plugins_sogen-support/tenet/ui/palette.py new file mode 100644 index 0000000..7921fb8 --- /dev/null +++ b/plugins_sogen-support/tenet/ui/palette.py @@ -0,0 +1,574 @@ +import os +import json +import shutil +import logging + +from json.decoder import JSONDecodeError + +from tenet.util.qt import * +from tenet.util.misc import * +from tenet.util.log import pmsg +from tenet.integration.api import disassembler + +logger = logging.getLogger("Plugin.UI.Palette") + +#------------------------------------------------------------------------------ +# Plugin Color Palette +#------------------------------------------------------------------------------ + +class PluginPalette(object): + """ + Theme palette for the plugin. + """ + + def __init__(self): + """ + Initialize default palette colors for the plugin. + """ + self._initialized = False + self._last_directory = None + self._required_fields = [] + + # hints about the user theme (light/dark) + self._user_qt_hint = "dark" + self._user_disassembly_hint = "dark" + + self.theme = None + self._default_themes = \ + { + "dark": "synth.json", + "light": "horizon.json" + } + + # list of objects requesting a callback after a theme change + self._theme_changed_callbacks = [] + + # get a list of required theme fields, for user theme validation + self._load_required_fields() + + # initialize the user theme directory + self._populate_user_theme_dir() + + # load a placeholder theme for inital Tenet bring-up + self._load_default_theme() + self._initialized = False + + @staticmethod + def get_plugin_theme_dir(): + """ + Return the plugin theme directory. + """ + return plugin_resource("themes") + + @staticmethod + def get_user_theme_dir(): + """ + Return the user theme directory. + """ + theme_directory = os.path.join( + disassembler.get_disassembler_user_directory(), + "tenet_themes" + ) + return theme_directory + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + def theme_changed(self, callback): + """ + Subscribe a callback for theme change events. + """ + register_callback(self._theme_changed_callbacks, callback) + + def _notify_theme_changed(self): + """ + Notify listeners of a theme change event. + """ + notify_callback(self._theme_changed_callbacks) + + #---------------------------------------------------------------------- + # Public + #---------------------------------------------------------------------- + + def warmup(self): + """ + Warms up the theming system prior to initial use. + """ + if self._initialized: + return + + logger.debug("Warming up theme subsystem...") + + # attempt to load the user's preferred theme + if self._load_preferred_theme(): + self._initialized = True + logger.debug(" - warmup complete, using user theme!") + return + + # + # if no user selected theme is loaded, we will attempt to detect + # and load the in-box themes based on the disassembler theme + # + + if self._load_hinted_theme(): + logger.debug(" - warmup complete, using hint-recommended theme!") + self._initialized = True + return + + pmsg("Could not warmup theme subsystem!") + + def interactive_change_theme(self): + """ + Open a file dialog and let the user select a new plugin theme. + """ + + # create & configure a Qt File Dialog for immediate use + file_dialog = QtWidgets.QFileDialog( + None, + "Open plugin theme file", + self._last_directory, + "JSON Files (*.json)" + ) + file_dialog.setFileMode(QtWidgets.QFileDialog.ExistingFile) + + # prompt the user with the file dialog, and await filename(s) + filename, _ = file_dialog.getOpenFileName() + if not filename: + return + + # + # ensure the user is only trying to load themes from the user theme + # directory as it helps ensure some of our intenal loading logic + # + + file_dir = os.path.abspath(os.path.dirname(filename)) + user_dir = os.path.abspath(self.get_user_theme_dir()) + if file_dir != user_dir: + text = "Please install your plugin theme into the user theme directory:\n\n" + user_dir + disassembler.warning(text) + return + + # + # remember the last directory we were in (parsed from a selected file) + # for the next time the user comes to load a theme file + # + + if filename: + self._last_directory = os.path.dirname(filename) + os.sep + + # log the captured (selected) filenames from the dialog + logger.debug("Captured filename from theme file dialog: '%s'" % filename) + + # + # before applying the selected plugin theme, we should ensure that + # we know if the user is using a light or dark disassembler theme as + # it may change which colors get used by the plugin theme + # + + self._refresh_theme_hints() + + # if the selected theme fails to load, throw a visible warning + if not self._load_theme(filename): + disassembler.warning( + "Failed to load plugin user theme!\n\n" + "Please check the console for more information..." + ) + return + + # since everthing looks like it loaded okay, save this as the preferred theme + with open(os.path.join(self.get_user_theme_dir(), ".active_theme"), "w") as f: + f.write(filename) + + def refresh_theme(self): + """ + Dynamically compute palette color based on the disassembler theme. + + Depending on if the disassembler is using a dark or light theme, we + *try* to select colors that will hopefully keep things most readable. + """ + if self._load_preferred_theme(): + return + if self._load_hinted_theme(): + return + pmsg("Failed to refresh theme!") + + def gen_arrow_icon(self, color, rotation): + """ + Dynamically generate a colored/rotated arrow icon. + """ + icon_path = plugin_resource(os.path.join("icons", "arrow.png")) + + img = QtGui.QPixmap(icon_path) + + if rotation: + rm = QtGui.QTransform() + rm.rotate(rotation) + img = img.transformed(rm) + + mask = QtGui.QPixmap(img) + + p = QtGui.QPainter() + p.begin(mask) + p.setCompositionMode(QtGui.QPainter.CompositionMode_SourceIn) + p.fillRect(img.rect(), color) + p.end() + + p.begin(img) + p.setCompositionMode(QtGui.QPainter.CompositionMode_Overlay) + p.drawPixmap(0, 0, mask) + p.end() + + # convert QPixmap to bytes + ba = QtCore.QByteArray() + buff = QtCore.QBuffer(ba) + buff.open(QtCore.QIODevice.WriteOnly) + ok = img.save(buff, "PNG", quality=100) + assert ok + + return ba.data() + + #-------------------------------------------------------------------------- + # Theme Internals + #-------------------------------------------------------------------------- + + def _populate_user_theme_dir(self): + """ + Create the plugin's user theme directory and install default themes. + """ + + # create the user theme directory if it does not exist + user_theme_dir = self.get_user_theme_dir() + makedirs(user_theme_dir) + + # copy the default themes into the user directory if they don't exist + for theme_name in self._default_themes.values(): + + # + # check if the plugin has copied the default themes into the user + # theme directory before. when 'default' themes exists, skip them + # rather than overwriting... as the user may have modified it + # + + user_theme_file = os.path.join(user_theme_dir, theme_name) + if os.path.exists(user_theme_file): + continue + + # copy the in-box themes to the user theme directory + plugin_theme_file = os.path.join(self.get_plugin_theme_dir(), theme_name) + shutil.copy(plugin_theme_file, user_theme_file) + + # + # if the user tries to switch themes, ensure the file dialog will start + # in their user theme directory + # + + self._last_directory = user_theme_dir + + def _load_required_fields(self): + """ + Load the required theme fields from a donor in-box theme. + """ + logger.debug("Loading required theme fields from disk...") + + # load a known-good theme from the plugin's in-box themes + filepath = os.path.join(self.get_plugin_theme_dir(), self._default_themes["dark"]) + theme = self._read_theme(filepath) + + # + # save all the defined fields in this 'good' theme as a ground truth + # to validate user themes against... + # + + self._required_fields = theme["fields"].keys() + + def _load_default_theme(self): + """ + Load the default theme without any sort of hinting. + """ + theme_name = self._default_themes["dark"] + theme_path = os.path.join(self.get_plugin_theme_dir(), theme_name) + return self._load_theme(theme_path) + + def _load_hinted_theme(self): + """ + Load the in-box plugin theme hinted at by the theme subsystem. + """ + self._refresh_theme_hints() + + # + # we have two themes hints which roughly correspond to the tone of + # the user's disassembly background, and then the Qt subsystem. + # + # if both themes seem to align on style (eg the user is using a + # 'dark' UI), then we will select the appropriate in-box theme + # + + if self._user_qt_hint == self._user_disassembly_hint: + theme_name = self._default_themes[self._user_qt_hint] + logger.debug(" - No preferred theme, hints suggest theme '%s'" % theme_name) + + # + # the UI hints don't match, so the user is using some ... weird + # mismatched theming in their disassembler. let's just default to + # the 'dark' plugin theme as it is more robust + # + + else: + theme_name = self._default_themes["dark"] + + # build the filepath to the hinted, in-box theme + theme_path = os.path.join(self.get_plugin_theme_dir(), theme_name) + + # attempt to load and return the result of loading an in-box theme + return self._load_theme(theme_path) + + def _load_preferred_theme(self): + """ + Load the user's saved, preferred theme. + """ + logger.debug("Loading preferred theme from disk...") + user_theme_dir = self.get_user_theme_dir() + + # attempt te read the name of the user's active / preferred theme name + active_filepath = os.path.join(user_theme_dir, ".active_theme") + try: + theme_name = open(active_filepath).read().strip() + logger.debug(" - Got '%s' from .active_theme" % theme_name) + except (OSError, IOError): + return False + + # build the filepath to the user defined theme + theme_path = os.path.join(self.get_user_theme_dir(), theme_name) + + # finally, attempt to load & apply the theme -- return True/False + if self._load_theme(theme_path): + return True + + # + # failed to load the preferred theme... so delete the 'active' + # file (if there is one) and warn the user before falling back + # + + try: + os.remove(os.path.join(self.get_user_theme_dir(), ".active_theme")) + except: + pass + + disassembler.warning( + "Failed to load plugin user theme!\n\n" + "Please check the console for more information..." + ) + + return False + + def _validate_theme(self, theme): + """ + Pefrom rudimentary theme validation. + """ + logger.debug(" - Validating theme fields for '%s'..." % theme["name"]) + user_fields = theme.get("fields", None) + if not user_fields: + pmsg("Could not find theme 'fields' definition") + return False + + # check that all the 'required' fields exist in the given theme + for field in self._required_fields: + if field not in user_fields: + pmsg("Could not find required theme field '%s'" % field) + return False + + # theme looks good enough for now... + return True + + def _load_theme(self, filepath): + """ + Load and apply the plugin theme at the given filepath. + """ + + # attempt to read json theme from disk + try: + theme = self._read_theme(filepath) + + # reading file from dsik failed + except OSError: + pmsg("Could not open theme file at '%s'" % filepath) + return False + + # JSON decoding failed + except JSONDecodeError as e: + pmsg("Failed to decode theme '%s' to json" % filepath) + pmsg(" - " + str(e)) + return False + + # do some basic sanity checking on the given theme file + if not self._validate_theme(theme): + pmsg("Failed to validate theme '%s'" % filepath) + return False + + # try applying the loaded theme to the plugin + try: + self._apply_theme(theme) + except Exception as e: + pmsg("Failed to load the plugin user theme\n%s" % e) + return False + + # return success + self._notify_theme_changed() + return True + + def _read_theme(self, filepath): + """ + Parse the plugin theme file from the given filepath. + """ + logger.debug(" - Reading theme file '%s'..." % filepath) + + # attempt to load the theme file contents from disk + raw_theme = open(filepath, "r").read() + + # convert the theme file contents to a json object/dict + theme = json.loads(raw_theme) + + # all good + return theme + + def _apply_theme(self, theme): + """ + Apply the given theme definition to the plugin. + """ + logger.debug(" - Applying theme '%s'..." % theme["name"]) + colors = theme["colors"] + + for field_name, color_entry in theme["fields"].items(): + + # color has 'light' and 'dark' variants + if isinstance(color_entry, list): + color_name = self._pick_best_color(field_name, color_entry) + + # there is only one color defined + else: + color_name = color_entry + + # load the color + color_value = colors[color_name] + color = QtGui.QColor(*color_value) + + # set theme self.[field_name] = color + setattr(self, field_name, color) + + # all done, save the theme in case we need it later + self.theme = theme + + def _pick_best_color(self, field_name, color_entry): + """ + Given a variable color_entry, select the best color based on the theme hints. + + TODO: Most of this file is ripped from Lighthouse, including this func. In + Lighthouse is behaves a bit different than it does here, but I'm too lazy + to refactor/remove it for now (and maybe it'll get used later on??) + """ + assert len(color_entry) == 2, "Malformed color entry, must be (dark, light)" + dark, light = color_entry + + if self._user_qt_hint == "dark": + return dark + + return light + + #-------------------------------------------------------------------------- + # Theme Inference + #-------------------------------------------------------------------------- + + def _refresh_theme_hints(self): + """ + Peek at the UI context to infer what kind of theme the user might be using. + """ + self._user_qt_hint = self._qt_theme_hint() + self._user_disassembly_hint = self._disassembly_theme_hint() or "dark" + + def _disassembly_theme_hint(self): + """ + Binary hint of the disassembler color theme. + + This routine returns a best effort hint as to what kind of theme is + in use for the IDA Views (Disas, Hex, HexRays, etc). + + Returns 'dark' or 'light' indicating the user's theme + """ + + # + # determine whether to use a 'dark' or 'light' paint based on the + # background color of the user's disassembly text based windows + # + + bg_color = disassembler.get_disassembly_background_color() + if not bg_color: + logger.debug(" - Failed to get hint for disassembly background...") + return None + + # return 'dark' or 'light' + return test_color_brightness(bg_color) + + def _qt_theme_hint(self): + """ + Binary hint of the Qt color theme. + + This routine returns a best effort hint as to what kind of theme the + QtWdigets throughout IDA are using. This is to accomodate for users + who may be using Zyantific's IDASkins plugins (or others) to further + customize IDA's appearance. + + Returns 'dark' or 'light' indicating the user's theme + """ + + # + # to determine what kind of Qt based theme IDA is using, we create a + # test widget and check the colors put into the palette the widget + # inherits from the application (eg, IDA). + # + + test_widget = QtWidgets.QWidget() + + # + # in order to 'realize' the palette used to render (draw) the widget, + # it first must be made visible. since we don't want to be popping + # random widgets infront of the user, so we set this attribute such + # that we can silently bake the widget colors. + # + # NOTE/COMPAT: WA_DontShowOnScreen + # + # https://www.riverbankcomputing.com/news/pyqt-56 + # + # lmao, don't ask me why they forgot about this attribute from 5.0 - 5.6 + # + + if disassembler.NAME == "BINJA": + test_widget.setAttribute(QtCore.Qt.WA_DontShowOnScreen) + else: + test_widget.setAttribute(103) # taken from http://doc.qt.io/qt-5/qt.html + + # render the (invisible) widget + test_widget.show() + + # now we farm the background color from the qwidget + bg_color = test_widget.palette().color(QtGui.QPalette.Window) + + # 'hide' & delete the widget + test_widget.hide() + test_widget.deleteLater() + + # return 'dark' or 'light' + return test_color_brightness(bg_color) + +#----------------------------------------------------------------------------- +# Palette Util +#----------------------------------------------------------------------------- + +def test_color_brightness(color): + """ + Test the brightness of a color. + """ + if color.lightness() > 255.0/2: + return "light" + else: + return "dark" diff --git a/plugins_sogen-support/tenet/ui/reg_view.py b/plugins_sogen-support/tenet/ui/reg_view.py new file mode 100644 index 0000000..42c4732 --- /dev/null +++ b/plugins_sogen-support/tenet/ui/reg_view.py @@ -0,0 +1,545 @@ +import collections + +from tenet.types import BreakpointType +from tenet.util.qt import * +from tenet.integration.api import disassembler + +class RegisterView(QtWidgets.QWidget): + """ + A container for the the widgets that make up the Registers view. + """ + + def __init__(self, controller, model, parent=None): + super(RegisterView, self).__init__(parent) + self.controller = controller + self.model = model + self._init_ui() + + def _init_ui(self): + + # child widgets + self.reg_area = RegisterArea(self.controller, self.model, self) + self.idx_shell = TimestampShell(self.controller, self.model, self) + self.setMinimumWidth(self.reg_area.minimumWidth()) + + # layout + layout = QtWidgets.QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self.reg_area) + layout.addWidget(self.idx_shell) + self.setLayout(layout) + + def refresh(self): + self.reg_area.refresh() + self.idx_shell.update() + +class TimestampShell(QtWidgets.QWidget): + + def __init__(self, controller, model, parent=None): + super(TimestampShell, self).__init__(parent) + self.model = model + self.controller = controller + self._init_ui() + + def _init_ui(self): + + # child widgets + self.head = QtWidgets.QLabel("Position", self) + self.shell = TimestampLine(self.model, self.controller, self) + + # events + self.model.registers_changed(self.refresh) + + # layout + layout = QtWidgets.QHBoxLayout(self) + #layout.setContentsMargins(5, 0, 5, 0) + layout.setContentsMargins(5, 0, 0, 5) + layout.addWidget(self.head) + layout.addWidget(self.shell) + + def refresh(self): + self.shell.setText(f"{self.model.idx:,}") + +class TimestampLine(QtWidgets.QLineEdit): + def __init__(self, model, controller, parent=None): + super(TimestampLine, self).__init__(parent) + self.model = model + self.controller = controller + self._init_ui() + + def _init_ui(self): + self.setStyleSheet( + f""" + QLineEdit {{ + background-color: {self.controller.pctx.palette.reg_bg.name()}; + color: {self.controller.pctx.palette.reg_value_fg.name()}; + }} + """ + ) + self.returnPressed.connect(self._evaluate) + + def _evaluate(self): + self.controller.evaluate_expression(self.text()) + +class RegisterArea(QtWidgets.QAbstractScrollArea): + """ + A Qt-based CPU register view. + """ + def __init__(self, controller, model, parent=None): + super(RegisterArea, self).__init__(parent) + self.pctx = controller.pctx + self.controller = controller + self.model = model + + font = QtGui.QFont("Courier", pointSize=normalize_font(9)) + font.setStyleHint(QtGui.QFont.TypeWriter) + self.setFont(font) + + fm = QtGui.QFontMetricsF(font) + self._char_width = fm.width('9') + self._char_height = fm.height() + + # default to fit roughly 50 printable characters + self._default_width = self._char_width * (self.pctx.arch.POINTER_SIZE * 2 + 16) + + # register drawing information + self._reg_pos = (self._char_width, self._char_height) + self._reg_fields = {} + self._hovered_arrow = None + + self.setFocusPolicy(QtCore.Qt.StrongFocus) + self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) + self.setMinimumWidth(int(self._reg_pos[0] + self._default_width)) + self.setMouseTracking(True) + + self._init_ctx_menu() + self._init_reg_positions() + + self.model.registers_changed(self.refresh) + + def sizeHint(self): + width = int(self._default_width) + height = int((len(self._reg_fields) + 2) * self._char_height) # +2 for line break before IP, and after IP + return QtCore.QSize(width, height) + + def _init_ctx_menu(self): + """ + Initialize the right click context menu actions. + """ + + # create actions to show in the context menu + self._action_copy_value = QtWidgets.QAction("Copy value", None) + self._action_follow_in_dump = QtWidgets.QAction("Follow in dump", None) + self._action_follow_in_disassembly = QtWidgets.QAction("Follow in disassembler", None) + self._action_clear = QtWidgets.QAction("Clear code breakpoints", None) + + # install the right click context menu + self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) + self.customContextMenuRequested.connect(self._ctx_menu_handler) + + def _init_reg_positions(self): + """ + Initialize register positions in the window. + """ + regs = self.model.arch.REGISTERS + name_x, y = self._reg_pos + + # find the most common length of a register name + reg_char_counts = collections.Counter([len(x) for x in regs]) + common_count, _ = reg_char_counts.most_common(1)[0] + + # compute rects for the average reg labels and values + fm = QtGui.QFontMetricsF(self.font()) + name_size = fm.boundingRect('X'*common_count).size() + value_size = fm.boundingRect('0' * (self.model.arch.POINTER_SIZE * 2)).size() + arrow_size = (int(value_size.height() * 0.70) | 1) + + # pre-compute the position of each register in the window + for reg_name in regs: + + # kind of dirty, but this will push IP a bit further away from the + # rest of the registers (it should be the last defined one...) + if reg_name == self.model.arch.IP: + y += self._char_height + + name_rect = QtCore.QRect(0, 0, int(name_size.width()), int(name_size.height())) + name_rect.moveBottomLeft(QtCore.QPoint(int(name_x), int(y))) + + prev_rect = QtCore.QRect(0, 0, int(arrow_size), int(arrow_size)) + next_rect = QtCore.QRect(0, 0, int(arrow_size), int(arrow_size)) + arrow_rects = [prev_rect, next_rect] + + prev_x = name_x + name_size.width() + self._char_width + prev_rect.moveCenter(name_rect.center()) + prev_rect.moveLeft(int(prev_x)) + + value_x = prev_x + prev_rect.width() + self._char_width + value_rect = QtCore.QRect(0, 0, int(value_size.width()), int(value_size.height())) + value_rect.moveBottomLeft(QtCore.QPoint(int(value_x), int(y))) + + next_x = value_x + value_size.width() + self._char_width + next_rect.moveCenter(name_rect.center()) + next_rect.moveLeft(int(next_x)) + + # save the register shapes + self._reg_fields[reg_name] = RegisterField(reg_name, name_rect, value_rect, arrow_rects) + + # increment y (to the next line) + y += self._char_height + + def _ctx_menu_handler(self, position): + """ + Handle a right click event (populate/show context menu). + """ + menu = QtWidgets.QMenu() + + # if a register was right clicked, fetch its name + reg_name = self._pos_to_reg(position) + if reg_name: + + # + # fetch the disassembler context and register value as we may use them + # based on the user's context, or the action they select + # + + dctx = disassembler[self.controller.pctx] + reg_value = self.model.registers[reg_name] + + # + # dynamically populate the right click context menu + # + + menu.addAction(self._action_copy_value) + menu.addAction(self._action_follow_in_dump) + + # + # if the register conatins a value that falls within the database, + # we want to show it and ensure it's active + # + + menu.addAction(self._action_follow_in_disassembly) + if dctx.is_mapped(reg_value): + self._action_follow_in_disassembly.setEnabled(True) + else: + self._action_follow_in_disassembly.setEnabled(False) + + # + # add a menu option to clear exection breakpoints if there is an + # active execution breakpoint set somewhere + # + + menu.addAction(self._action_clear) + + # + # show the right click menu and wait for the user to selection an + # action from the list of visible/active actions + # + + action = menu.exec_(self.mapToGlobal(position)) + + # + # handle the user selected action + # + + if action == self._action_copy_value: + copy_to_clipboard("0x%08X" % reg_value) + elif action == self._action_follow_in_disassembly: + dctx.navigate(reg_value) + elif action == self._action_follow_in_dump: + self.controller.follow_in_dump(reg_name) + elif action == self._action_clear: + self.pctx.breakpoints.clear_execution_breakpoints() + + def refresh(self): + self.viewport().update() + + def _pos_to_field(self, pos): + """ + Get the register field at the given cursor position. + """ + for reg_name, field in self._reg_fields.items(): + full_field = QtCore.QRect(field.name_rect.topLeft(), field.next_rect.bottomRight()) + if full_field.contains(pos): + return field + return None + + def _pos_to_reg(self, pos): + """ + Get the register name at the given cursor position. + """ + reg_field = self._pos_to_field(pos) + return reg_field.name if reg_field else None + + def full_size(self): + if not self.model.registers: + return QtCore.QSize(0, 0) + + width = int(self._reg_pos[0] + self._default_width) + height = int(len(self.model.registers) * self._char_height) + + return QtCore.QSize(width, height) + + def wheelEvent(self, event): + """ + Qt overload to capture wheel events. + """ + + # no execution breakpoints set, nothing to do + if not self.pctx.breakpoints.model.bp_exec: + return + + # mouse hover was not over IP register value, nothing to do + field = self._pos_to_field(event.pos()) + if not (field and field.name == self.model.arch.IP): + return + + # get the IP value currently displayed in the reg window + current_ip = self.model.registers[self.model.arch.IP] + breakpoints = self.pctx.breakpoints.model.bp_exec + + # loop through the execution-based breakpoints + for breakpoint_address in breakpoints: + if breakpoint_address == current_ip: + break + + # no execution breakpoints match the hovered IP + else: + return + + # scroll up + if event.angleDelta().y() > 0: + self.pctx.reader.seek_to_prev(current_ip, BreakpointType.EXEC) + + # scroll down + elif event.angleDelta().y() < 0: + self.pctx.reader.seek_to_next(current_ip, BreakpointType.EXEC) + + return + + def mouseMoveEvent(self, e): + """ + Qt overload to capture mouse movement events. + """ + point = e.pos() + before = self._hovered_arrow + + for reg_name, reg_field in self._reg_fields.items(): + if reg_field.next_rect.contains(point): + self._hovered_arrow = reg_field.next_rect + break + elif reg_field.prev_rect.contains(point): + self._hovered_arrow = reg_field.prev_rect + break + else: + self._hovered_arrow = None + + if before != self._hovered_arrow: + self.viewport().update() + + def mouseDoubleClickEvent(self, event): + """ + Qt overload to capture mouse double-click events. + """ + mouse_position = event.pos() + + # handle duoble (left) click events + if event.button() == QtCore.Qt.LeftButton: + + # confirm that we are consuming the double click event + event.accept() + + # check if the user clicked a known field + field = self._pos_to_field(mouse_position) + + # if the double click was *not* on a register field, clear execution breakpoints + if not field: + self.pctx.breakpoints.clear_execution_breakpoints() + return + + # ignore if the double clicked field (register) was not the IP reg + if not (field and field.name == self.model.arch.IP): + return + + # ignore if the double click was not on the reg value + if not field.value_rect.contains(mouse_position): + return + + # the user double clicked IP, so set a breakpoint on it + self.controller.set_ip_breakpoint() + + def mousePressEvent(self, event): + """ + Qt overload to capture mouse button presses. + """ + mouse_position = event.pos() + + # handle click events + if event.button() == QtCore.Qt.LeftButton: + + # check if the user clicked a known field + field = self._pos_to_field(mouse_position) + + # no field (register name, or register value) was selected + if not field: + self.controller.clear_register_focus() + + # the user clicked on the register value + elif field.value_rect.contains(mouse_position): + self.controller.focus_register_value(field.name) + + # the user clicked on the 'seek to next reg change' arrow + elif field.next_rect.contains(mouse_position): + result = self.pctx.reader.find_next_register_change(field.name) + if result != -1: + self.pctx.reader.seek(result) + + # the user clicked on the 'seek to prev reg change' arrow + elif field.prev_rect.contains(mouse_position): + result = self.pctx.reader.find_prev_register_change(field.name) + if result != -1: + self.pctx.reader.seek(result) + + # the user clicked on the register name + else: + self.controller.focus_register_name(field.name) + + # update the view as selection / drawing may change + self.viewport().update() + + def paintEvent(self, event): + """ + Qt overload of widget painting. + """ + + if not self.model.registers: + return + + painter = QtGui.QPainter(self.viewport()) + + area_size = self.viewport().size() + area_rect = self.viewport().rect() + widget_size = self.full_size() + + painter.fillRect(area_rect, self.pctx.palette.reg_bg) + + brush_defualt = painter.brush() + brush_selected = QtGui.QBrush(self.pctx.palette.standard_selection_bg) + + for reg_name in self.model.arch.REGISTERS: + reg_value = self.model.registers[reg_name] + reg_field = self._reg_fields[reg_name] + + # coloring for when the register is selected by the user + if reg_name == self.model.focused_reg_name: + painter.setBackground(brush_selected) + painter.setBackgroundMode(QtCore.Qt.OpaqueMode) + painter.setPen(self.pctx.palette.standard_selection_fg) + + # default / unselected register colors + else: + painter.setBackground(brush_defualt) + painter.setBackgroundMode(QtCore.Qt.OpaqueMode) + painter.setPen(self.pctx.palette.reg_name_fg) + + # draw register name + painter.drawText(reg_field.name_rect, QtCore.Qt.AlignCenter, reg_name) + + reg_nibbles = self.model.arch.POINTER_SIZE * 2 + if reg_value is None: + rendered_value = "?" * reg_nibbles + else: + rendered_value = f'%0{reg_nibbles}X' % reg_value + + # color register if its value changed as a result of T-1 (previous instr) + if reg_name in self.model.delta_trace: + painter.setPen(self.pctx.palette.reg_changed_trace_fg) + + # color register if its value changed as a result of navigation + # TODO: disabled for now, because it seemed more confusing than helpful... + elif reg_name in self.model.delta_navigation and False: + painter.setPen(self.pctx.palette.reg_changed_navigation_fg) + + # no special highlighting, default register value color text + else: + painter.setPen(self.pctx.palette.reg_value_fg) + + # coloring for when the register is selected by the user + if reg_name == self.model.focused_reg_value: + painter.setPen(self.pctx.palette.standard_selection_fg) + painter.setBackground(brush_selected) + painter.setBackgroundMode(QtCore.Qt.OpaqueMode) + + # default / unselected register colors + else: + painter.setBackground(brush_defualt) + painter.setBackgroundMode(QtCore.Qt.OpaqueMode) + + # special highlighting of the instruction pointer if it matches an active breakpoint + if reg_name == self.model.arch.IP: + if reg_value in self.model.execution_breakpoints: + painter.setPen(self.pctx.palette.navigation_selection_fg) + painter.setBackground(self.pctx.palette.navigation_selection_bg) + + # draw register value + painter.drawText(reg_field.value_rect, QtCore.Qt.AlignCenter, rendered_value) + + # don't draw arrows next to RIP's value + if reg_name == self.model.arch.IP: + continue + + # draw register arrows + for i, rect in enumerate([reg_field.prev_rect, reg_field.next_rect]): + self._draw_arrow(painter, rect, i) + + def _draw_arrow(self, painter, rect, index): + path = QtGui.QPainterPath() + + size = rect.height() + assert size % 2, "Cursor triangle size must be odd" + + # the top point of the triangle + top_x = rect.x() + (0 if index else rect.width()) + top_y = rect.y() + 1 + + # bottom point of the triangle + bottom_x = top_x + bottom_y = top_y + size - 1 + + # the 'tip' of the triangle pointing into towards the center of the trace + tip_x = top_x + ((size // 2) * (1 if index else -1)) + tip_y = top_y + (size // 2) + + # start drawing from the 'top' of the triangle + path.moveTo(top_x, top_y) + + # generate the triangle path / shape + path.lineTo(bottom_x, bottom_y) + path.lineTo(tip_x, tip_y) + path.lineTo(top_x, top_y) + + # dev / debug helper + #painter.setPen(QtCore.Qt.green) + #painter.setBrush(QtGui.QBrush(QtGui.QColor("white"))) + #painter.drawRect(rect) + + # paint the defined triangle + # TODO: don't hardcode colors + painter.setPen(QtCore.Qt.black) + + if self._hovered_arrow == rect: + if index: + painter.setBrush(self.pctx.palette.arrow_next) + else: + painter.setBrush(self.pctx.palette.arrow_prev) + else: + painter.setBrush(self.pctx.palette.arrow_idle) + + painter.drawPath(path) + +class RegisterField(object): + def __init__(self, name, name_rect, value_rect, arrow_rects): + self.name = name + self.name_rect = name_rect + self.value_rect = value_rect + self.prev_rect = arrow_rects[0] + self.next_rect = arrow_rects[1] \ No newline at end of file diff --git a/plugins_sogen-support/tenet/ui/resources/icons/arrow.png b/plugins_sogen-support/tenet/ui/resources/icons/arrow.png new file mode 100644 index 0000000..e7e7b1f Binary files /dev/null and b/plugins_sogen-support/tenet/ui/resources/icons/arrow.png differ diff --git a/plugins_sogen-support/tenet/ui/resources/themes/horizon.json b/plugins_sogen-support/tenet/ui/resources/themes/horizon.json new file mode 100644 index 0000000..2f5eed1 --- /dev/null +++ b/plugins_sogen-support/tenet/ui/resources/themes/horizon.json @@ -0,0 +1,94 @@ +{ + "name": "Horizon", + + "colors": + { + "black": [ 0, 0, 0], + "white": [255, 255, 255], + + "lightest_gray": [241, 241, 241], + "lighter_gray": [210, 210, 210], + "light_gray": [160, 160, 160], + "gray": [ 80, 80, 80], + "dark_gray": [ 40, 40, 40], + "darkest_gray": [ 25, 25, 25], + + "true_red": [255, 0, 0], + "lightest_red": [255, 80, 80], + "light_red": [170, 57, 57], + "red": [118, 0, 0], + "dark_red": [ 85, 0, 0], + + "orange": [255, 165, 0], + + "true_yellow": [255, 255, 0], + "yellow": [255, 193, 7], + + "true_green": [ 0, 255, 0], + "light_green": [128, 255, 128], + + "purple": [150, 20, 150], + "dark_purple": [ 38, 23, 88], + + "light_blue": [ 33, 159, 255], + "blue": [ 30, 136, 229], + "dark_blue": [ 58, 58, 128] + + }, + + "fields": + { + "navigation_selection_bg": "lightest_red", + "navigation_selection_fg": "white", + "navigation_selection_faded_fg": "light_red", + + "standard_selection_bg": "lighter_gray", + "standard_selection_fg": "black", + "standard_selection_faded_fg": "light_gray", + + "reg_bg": "white", + + "reg_name_fg": "black", + "reg_value_fg": "black", + "reg_changed_trace_fg": "true_red", + "reg_changed_navigation_fg": "orange", + + "hex_text_fg": "black", + "hex_text_faded_fg": "light_gray", + + "hex_address_fg": "darkest_gray", + "hex_address_bg": "lightest_gray", + + "hex_data_bg": "white", + "hex_separator": "black", + + "trace_bedrock": "dark_gray", + "trace_unmapped": "light_gray", + "trace_instruction": "dark_blue", + "trace_border": "gray", + + "trace_cell_wall": "light_gray", + "trace_cell_wall_contrast": "black", + + "trace_cursor": "true_red", + "trace_cursor_border": "black", + "trace_cursor_highlight": "true_green", + + "trace_selection": "true_green", + "trace_selection_border": "true_green", + + "breakpoint": "lightest_red", + "mem_read_bg": "yellow", + "mem_read_fg": "black", + "mem_write_bg": "light_blue", + "mem_write_fg": "white", + + "trail_backward": "lightest_red", + "trail_current": "light_green", + "trail_forward": "light_blue", + + "arrow_prev": "lightest_red", + "arrow_next": "light_blue", + "arrow_idle": "light_gray" + } +} \ No newline at end of file diff --git a/plugins_sogen-support/tenet/ui/resources/themes/synth.json b/plugins_sogen-support/tenet/ui/resources/themes/synth.json new file mode 100644 index 0000000..218dabb --- /dev/null +++ b/plugins_sogen-support/tenet/ui/resources/themes/synth.json @@ -0,0 +1,93 @@ +{ + "name": "Synth", + + "colors": + { + "black": [ 0, 0, 0], + "white": [255, 255, 255], + + "lightest_gray": [221, 221, 221], + "light_gray": [160, 160, 160], + "gray": [ 90, 90, 90], + "medium_gray": [ 60, 60, 60], + "dark_gray": [ 40, 40, 40], + "darkest_gray": [ 25, 25, 25], + + "true_red": [255, 0, 0], + "light_red": [170, 57, 57], + "red": [118, 0, 0], + "dark_red": [ 85, 0, 0], + + "orange": [255, 165, 0], + + "true_yellow": [255, 255, 0], + "yellow": [212, 194, 106], + + "true_green": [ 0, 255, 0], + + "purple": [150, 20, 150], + "dark_purple": [ 38, 23, 88], + + "light_blue": [144, 190, 252], + "blue": [ 30, 136, 229], + "dark_blue": [ 0, 0, 128] + + }, + + "fields": + { + "navigation_selection_bg": "light_red", + "navigation_selection_fg": "white", + "navigation_selection_faded_fg": "dark_red", + + "standard_selection_bg": "light_blue", + "standard_selection_fg": "black", + "standard_selection_faded_fg": "gray", + + "reg_bg": "darkest_gray", + + "reg_name_fg": "lightest_gray", + "reg_value_fg": "lightest_gray", + "reg_changed_trace_fg": "true_red", + "reg_changed_navigation_fg": "orange", + + "hex_text_fg": "lightest_gray", + "hex_text_faded_fg": "gray", + + "hex_address_fg": "light_gray", + "hex_address_bg": "dark_gray", + + "hex_data_bg": "darkest_gray", + "hex_separator": "black", + + "trace_bedrock": "darkest_gray", + "trace_unmapped": "medium_gray", + "trace_instruction": "dark_purple", + "trace_border": "gray", + + "trace_cell_wall": "medium_gray", + "trace_cell_wall_contrast": "black", + + "trace_cursor": "true_red", + "trace_cursor_border": "black", + "trace_cursor_highlight": "true_green", + + "trace_selection": "true_green", + "trace_selection_border": "true_green", + + "breakpoint": "light_red", + "mem_read_bg": "yellow", + "mem_read_fg": "black", + "mem_write_bg": "blue", + "mem_write_fg": "white", + + "trail_backward": "red", + "trail_current": "purple", + "trail_forward": "dark_blue", + + "arrow_prev": "true_red", + "arrow_next": "blue", + "arrow_idle": "gray" + } +} + diff --git a/plugins_sogen-support/tenet/ui/trace_view.py b/plugins_sogen-support/tenet/ui/trace_view.py new file mode 100644 index 0000000..ee05b02 --- /dev/null +++ b/plugins_sogen-support/tenet/ui/trace_view.py @@ -0,0 +1,1335 @@ +import logging + +from tenet.util.qt import * +from tenet.util.misc import register_callback, notify_callback +from tenet.integration.api import disassembler + +logger = logging.getLogger("Tenet.UI.TraceView") + +# +# TODO: BIG DISCLAIMER -- The trace visualization / window does *not* make +# use of the MVC pattern that the other widgets do. +# +# this is mainly due to the fact that it was prototyped last, and I haven't +# gotten around to moving the 'logic' out of window/widget classes and into +# a dedicated controller class. +# +# this will probably happen sooner than later, to keep everything consistent +# + +#------------------------------------------------------------------------------ +# TraceView +#------------------------------------------------------------------------------ + +INVALID_POS = -1 +INVALID_IDX = -1 +INVALID_DENSITY = -1 + +class TraceBar(QtWidgets.QWidget): + """ + A trace visualization. + """ + + def __init__(self, pctx, zoom=False, parent=None): + super(TraceBar, self).__init__(parent) + self.pctx = pctx + self.reader = None + self._is_zoom = zoom + + # misc qt/widget settings + self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) + self.setMouseTracking(True) + self.setMinimumSize(32, 32) + self._resize_timer = QtCore.QTimer(self) + self._resize_timer.setSingleShot(True) + self._resize_timer.timeout.connect(self._resize_stopped) + + # the first and last visible idx in this visualization + self.start_idx = 0 + self.end_idx = 0 + self._end_idx_internal = 0 + self._last_trace_idx = 0 + + # the 'uncommitted' / in-progress user selection of a trace region + self._idx_pending_selection_origin = INVALID_IDX + self._idx_pending_selection_start = INVALID_IDX + self._idx_pending_selection_end = INVALID_IDX + + # the committed user selection of a trace region + self._idx_selection_start = INVALID_IDX + self._idx_selection_end = INVALID_IDX + + # the idxs that should be highlighted based on user queries + self._idx_reads = [] + self._idx_writes = [] + self._idx_executions = [] + + # the magnetism distance (in pixels) for cursor clicks on viz events + self._magnetism_distance = 4 + self._hovered_idx = INVALID_IDX + + # listen for breakpoint changed events + pctx.breakpoints.model.breakpoints_changed(self._breakpoints_changed) + + #---------------------------------------------------------------------- + # Styling + #---------------------------------------------------------------------- + + # the width (in pixels) of the border around the trace bar + self._trace_border = 1 + + # the width (in pixels) of the border around trace cells + self._cell_border = 0 # computed dynamically + self._cell_min_border = 1 + self._cell_max_border = 1 + + # the height (in pixels) of the trace cells + self._cell_height = 0 # computed dynamically + self._cell_min_height = 2 + self._cell_max_height = 10 + + # the amount of space between cells (in pixels) + # - NOTE: no limit to cell spacing at max magnification! + self._cell_spacing = 0 # computed dynamically + self._cell_min_spacing = self._cell_min_border + + # the width (in pixels) of the border around user region selection + self._selection_border = 2 + + # create the rest of the painting vars + self._init_painting() + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + self._selection_changed_callbacks = [] + + def _init_painting(self): + """ + Initialize widget/trace painting elements. + """ + self._image_base = None + self._image_highlights = None + self._image_selection = None + self._image_border = None + self._image_cursor = None + self._image_final = None + + self._painter_base = None + self._painter_highlights = None + self._painter_selection = None + self._painter_border = None + self._painter_cursor = None + self._painter_final = None + + self._pen_cursor = QtGui.QPen(self.pctx.palette.trace_cursor_highlight, 1, QtCore.Qt.SolidLine) + + self._pen_selection = QtGui.QPen(self.pctx.palette.trace_selection, self._selection_border, QtCore.Qt.SolidLine) + self._brush_selection = QtGui.QBrush(QtCore.Qt.Dense6Pattern) + self._brush_selection.setColor(self.pctx.palette.trace_selection_border) + + self._last_hovered = INVALID_IDX + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def length(self): + """ + Return the number of idx visible in the trace visualization. + """ + return (self.end_idx - self.start_idx) + + @property + def cells_visible(self): + """ + Return True if the trace visualization is drawing as cells. + """ + return bool(self._cell_height) + + @property + def density(self): + """ + Return the density of idx (instructions) per y-pixel of the trace visualization. + """ + density = (self.length / (self.height() - self._trace_border * 2)) + if density > 0: + return density + return INVALID_DENSITY + + @property + def viz_rect(self): + """ + Return a QRect defining the drawable trace visualization. + """ + x, y = self.viz_pos + w, h = self.viz_size + return QtCore.QRect(x, y, w, h) + + @property + def viz_pos(self): + """ + Return (x, y) coordinates of the drawable trace visualization. + """ + return (self._trace_border, self._trace_border) + + @property + def viz_size(self): + """ + Return (width, height) of the drawable trace visualization. + """ + w = max(0, int(self.width() - (self._trace_border * 2))) + h = max(0, int(self.height() - (self._trace_border * 2))) + return (w, h) + + #------------------------------------------------------------------------- + # Public + #------------------------------------------------------------------------- + + def attach_reader(self, reader): + """ + Attach a trace reader to this controller. + """ + self.reset() + + # attach the new reader + self.reader = reader + + # initialize state based on the reader + self.set_bounds(0, reader.trace.length) + + # attach signals to the new reader + reader.idx_changed(self.refresh) + + def set_bounds(self, start_idx, end_idx): + """ + Set the idx bounds of the trace visualization. + """ + assert end_idx > start_idx, f"Invalid Bounds ({start_idx}, {end_idx})" + + # set the bounds of the trace + self.start_idx = max(0, start_idx) + self.end_idx = end_idx + self._end_idx_internal = end_idx + + # update drawing metrics, note that this can 'tweak' end_idx to improve cell rendering + self._refresh_painting_metrics() + + # compute the number of instructions visible + self._last_trace_idx = min(self.reader.trace.length, self.end_idx) + + # refresh/redraw relevant elements + self._refresh_trace_highlights() + self.refresh() + + # return the final / selected bounds + return (self.start_idx, self.end_idx) + + def set_selection(self, start_idx, end_idx): + """ + Set the selection region bounds. + """ + assert end_idx >= start_idx + self._idx_selection_start = start_idx + self._idx_selection_end = end_idx + self.refresh() + + def reset(self): + """ + Reset the trace visualization. + """ + self.reader = None + + self.start_idx = 0 + self.end_idx = 0 + self._last_trace_idx = 0 + + self._idx_pending_selection_origin = INVALID_IDX + self._idx_pending_selection_start = INVALID_IDX + self._idx_pending_selection_end = INVALID_IDX + + self._idx_selection_start = INVALID_IDX + self._idx_selection_end = INVALID_IDX + + self._idx_reads = [] + self._idx_writes = [] + self._idx_executions = [] + + self._refresh_painting_metrics() + self.refresh() + + def refresh(self, *args): + """ + Refresh the trace visualization. + """ + self.update() + + #---------------------------------------------------------------------- + # Qt Overloads + #---------------------------------------------------------------------- + + def mouseMoveEvent(self, event): + """ + Qt overload to capture mouse movement events. + """ + if not self.reader: + return + + # mouse moving while holding left button + if event.buttons() == QtCore.Qt.MouseButton.LeftButton: + self._update_selection(event.y()) + self.refresh() + return + + # simple mouse hover over viz + self._update_hover(event.y()) + self.refresh() + + def mousePressEvent(self, event): + """ + Qt overload to capture mouse button presses. + """ + if not self.reader: + return + + # left mouse button was pressed (but not yet released!) + if event.button() == QtCore.Qt.MouseButton.LeftButton: + idx_origin = self._pos2idx(event.y()) + self._idx_pending_selection_origin = idx_origin + self._idx_pending_selection_start = idx_origin + self._idx_pending_selection_end = idx_origin + + return + + def mouseReleaseEvent(self, event): + """ + Qt overload to capture mouse button releases. + """ + if not self.reader: + return + + # if the left mouse button was released... + if event.button() == QtCore.Qt.MouseButton.LeftButton: + + # + # no selection origin? this means the click probably started + # off this widget, and the user moved their mouse over viz + # ... before releasing... which is not something we care about + # + + if self._idx_pending_selection_origin == INVALID_IDX: + return + + # if the mouse press & release was on the same idx, probably a click + if self._idx_pending_selection_start == self._idx_pending_selection_end: + self._commit_click() + + # a range was selected, so accept/commit it + else: + self._commit_selection() + + def leaveEvent(self, _): + """ + Qt overload to capture the mouse hover leaving the widget. + """ + self._hovered_idx = INVALID_IDX + self.refresh() + + def wheelEvent(self, event): + """ + Qt overload to capture wheel events. + """ + if not self.reader: + return + + # holding the shift key while scrolling is used to 'step over' + mod_keys = QtGui.QGuiApplication.keyboardModifiers() + step_over = bool(mod_keys & QtCore.Qt.ShiftModifier) + + # scrolling up, so step 'backwards' through the trace + if event.angleDelta().y() > 0: + self.reader.step_backward(1, step_over) + + # scrolling down, so step 'forwards' through the trace + elif event.angleDelta().y() < 0: + self.reader.step_forward(1, step_over) + + self.refresh() + event.accept() + + def resizeEvent(self, _): + """ + Qt overload to capture resize events for the widget. + """ + self._resize_timer.start(500) + + #------------------------------------------------------------------------- + # Helpers (Internal) + #------------------------------------------------------------------------- + # + # NOTE: this stuff should probably only be called by the 'mainthread' + # to ensure density / viz dimensions and stuff don't change. + # + + def _resize_stopped(self): + """ + Delayed handler of resize events. + + We delay handling resize events because several resize events can + trigger when a user is dragging to resize a window. we only really + care to recompute the visualization when they stop 'resizing' it. + """ + self.set_bounds(self.start_idx, self._end_idx_internal) + + def _refresh_painting_metrics(self): + """ + Refresh any metrics and calculations required to paint the widget. + """ + self._cell_height = 0 + self._cell_border = 0 + self._cell_spacing = 0 + + # how many 'instruction' cells *must* be shown based on current selection? + num_cell = self._end_idx_internal - self.start_idx + if not num_cell: + return + + # how many 'y' pixels are available, per cell (including spacing, between cells) + _, viz_h = self.viz_size + given_space_per_cell = viz_h / num_cell + + # compute the smallest possible cell height, with overlapping cell borders + min_full_cell_height = self._cell_min_height + self._cell_min_border + + # don't draw the trace vizualization as cells if the density is too high + if given_space_per_cell < min_full_cell_height: + #logger.debug(f"No need for cells -- {given_space_per_cell}, min req {min_full_cell_height}") + return + + # compute the pixel height of a cell at maximum height (including borders) + max_cell_height_with_borders = self._cell_max_height + self._cell_max_border * 2 + + # compute how much leftover space there is to use between cells + spacing_between_max_cells = given_space_per_cell - max_cell_height_with_borders + + # maximum sized instruction cells, with 'infinite' possible spacing between cells + if spacing_between_max_cells > max_cell_height_with_borders: + self._cell_border = self._cell_max_border + self._cell_height = self._cell_max_height + self._cell_spacing = spacing_between_max_cells + return + + # dynamically compute cell dimensions for drawing + self._cell_height = max(self._cell_min_height, min(int(given_space_per_cell * 0.95), self._cell_max_height)) + self._cell_border = max(self._cell_min_border, min(int(given_space_per_cell * 0.05), self._cell_max_border)) + self._cell_spacing = int(given_space_per_cell - (self._cell_height + self._cell_border * 2)) + #logger.debug(f"Dynamic cells -- Given: {given_space_per_cell}, Height {self._cell_height}, Border: {self._cell_border}, Spacing: {self._cell_spacing}") + + # if there's not enough to justify having spacing, use shared borders between cells (usually very small cells) + if self._cell_spacing < self._cell_min_spacing: + self._cell_spacing = self._cell_min_border * -2 + + # compute the final number of y pixels used by each 'cell' (an executed instruction) + used_space_per_cell = self._cell_height + self._cell_border * 2 + self._cell_spacing + + # compute how many cells we can *actually* show in the space available + num_cell_allowed = int(viz_h / used_space_per_cell) + 1 + #logger.debug(f"Num Cells {num_cell} vs Available Space {num_cell_allowed}") + + self.end_idx = self.start_idx + num_cell_allowed + + def _idx2pos(self, idx): + """ + Translate a given idx to its first Y coordinate. + """ + if idx < self.start_idx or idx >= self.end_idx: + #logger.warn(f"idx2pos failed (start: {self.start_idx:,} idx: {idx:,} end: {self.end_idx:,}") + return INVALID_POS + + density = self.density + if density == INVALID_DENSITY: + #logger.warn(f"idx2pos failed (INVALID_DENSITY)") + return INVALID_POS + + # convert the absolute idx to one that is 'relative' to the viz + relative_idx = idx - self.start_idx + + # re-base y to the start of the viz region + _, y = self.viz_pos + + # + # compute and return an 'approximate' y position of the given idx + # when the visualization is not using cell metrics (too dense) + # + + if not self.cells_visible: + y += int(relative_idx / density) + + # sanity check + _, viz_y = self.viz_pos + _, viz_h = self.viz_size + assert y >= viz_y + assert y < (viz_y + viz_h) + + # return the approximate y position of the given timestamp + return y + + #assert self._cell_spacing % 2 == 0 + + # compute the y position of the 'first' cell + y += self._cell_spacing // 2 # pad out from top + y += self._cell_border # top border of cell + + # compute the y position of any given cell after the first + y += self._cell_height * relative_idx # cell body + y += self._cell_border * relative_idx # cell bottom border + y += self._cell_spacing * relative_idx # full space between cells + y += self._cell_border * relative_idx # cell top border + + # return the y position of the cell corresponding to the given timestamp + return y + + def _pos2idx(self, y): + """ + Translate a given Y coordinate to an approximate idx. + """ + _, viz_y = self.viz_pos + _, viz_h = self.viz_size + + # clamp clearly out-of-bounds requests to the start/end idx values + if y < viz_y: + return self.start_idx + elif y >= viz_y + viz_h: + return self.end_idx - 1 + + density = self.density + if density == INVALID_DENSITY: + #logger.warn(f"pos2idx failed (INVALID_DENSITY)") + return INVALID_IDX + + # translate/rebase global y to viz relative y + y -= self._trace_border + + # compute the relative idx based on how much space is used per cell + if self.cells_visible: + + # this is how many vertical pixel each cell uses, including spacing to the next cell + used_space_per_cell = self._cell_height + self._cell_border * 2 + self._cell_spacing + + # compute relative idx for cell-based views + y -= self._cell_border + relative_idx = int(y / used_space_per_cell) + + # compute the approximate relative idx using the instruction density metric + else: + relative_idx = round(y * density) + + # convert the viz-relative idx, to its global trace idx timestamp + idx = self.start_idx + relative_idx + + # clamp idx to the start / end of visible tracebar range + return self._clamp_idx(idx) + + def _compute_pixel_distance(self, y, idx): + """ + Compute the pixel distance from a given Y to an idx. + """ + + # get the y pixel position of the given idx + y_idx = self._idx2pos(idx) + if y_idx == INVALID_POS: + return -1 + + # + # if the visualization drawing cells, adjust the reported y coordinate + # of the given idx to the center of the cell. this makes distance + # calculations more correct + # + + if self.cells_visible: + y_idx += int(self._cell_height/2) + + # return the on-screen pixel distance between the two y coords + return abs(y - y_idx) + + def _update_hover(self, current_y): + """ + Update the trace visualization based on the mouse hover. + """ + self._hovered_idx = INVALID_IDX + + # see if there's an interesting trace event close to the hover + hovered_idx = self._pos2idx(current_y) + closest_idx = self._get_closest_highlighted_idx(hovered_idx) + + # + # if the closest highlighted event (mem access, breakpoint) + # is outside the trace view bounds, then we don't need to + # do any special hover highlighting... + # + + if not(self.start_idx <= closest_idx < self.end_idx): + return + + # + # compute the on-screen pixel distance between the hover and the + # closest highlighted event + # + + px_distance = self._compute_pixel_distance(current_y, closest_idx) + #logger.debug(f"hovered idx {hovered_idx:,}, closest idx {closest_idx:,}, dist {px_distance} (start: {self.start_idx:,} end: {self.end_idx:,}") + if px_distance == -1: + return + + # clamp the lock-on distance depending on the scale of zoom / cell size + lockon_distance = max(self._magnetism_distance, self._cell_height) + + # + # if the trace event is within the magnetized distance of the user + # cursor, lock on to it. this makes 'small' things easier to click + # + + if px_distance < lockon_distance: + self._hovered_idx = closest_idx + + def _update_selection(self, y): + """ + Update the user region selection of the trace visualization based on the current y. + """ + idx_event = self._pos2idx(y) + + if idx_event > self._idx_pending_selection_origin: + self._idx_pending_selection_start = self._idx_pending_selection_origin + self._idx_pending_selection_end = idx_event + else: + self._idx_pending_selection_end = self._idx_pending_selection_origin + self._idx_pending_selection_start = idx_event + + self._idx_selection_start = INVALID_IDX + self._idx_selection_end = INVALID_IDX + + def _global_selection_changed(self, start_idx, end_idx): + """ + Handle selection behavior specific to a 'global' trace visualizations. + """ + if start_idx == end_idx: + return + self.set_selection(start_idx, end_idx) + + def _zoom_selection_changed(self, start_idx, end_idx): + """ + Handle selection behavior specific to a 'zoomer' trace visualizations. + """ + if start_idx == end_idx: + self.hide() + else: + self.show() + self.set_bounds(start_idx, end_idx) + + def _commit_click(self): + """ + Accept a click event. + """ + selected_idx = self._idx_pending_selection_start + + # use a 'magnetized' selection, if available + if self._hovered_idx != INVALID_IDX: + selected_idx = self._hovered_idx + self._hovered_idx = INVALID_IDX + + # reset pending selection + self._idx_pending_selection_origin = INVALID_IDX + self._idx_pending_selection_start = INVALID_IDX + self._idx_pending_selection_end = INVALID_IDX + + # does the click fall within the existing selected region? + within_region = (self._idx_selection_start <= selected_idx <= self._idx_selection_end) + + # nope click is outside the region, so clear the region selection + if not within_region: + self._idx_selection_start = INVALID_IDX + self._idx_selection_end = INVALID_IDX + self._notify_selection_changed(INVALID_IDX, INVALID_IDX) + + #print(f"Jumping to {selected_idx:,}") + self.reader.seek(selected_idx) + self.refresh() + + def _commit_selection(self): + """ + Accept a selection event. + """ + new_start = self._idx_pending_selection_start + new_end = self._idx_pending_selection_end + + # reset pending selections + self._idx_pending_selection_origin = INVALID_IDX + self._idx_pending_selection_start = INVALID_IDX + self._idx_pending_selection_end = INVALID_IDX + + # + # if we just selected a new region on a trace viz that's a + # 'zoomer', then we will apply the zoom-in action to ourself by + # adjusting our visible regions (bounds) + # + # NOTE: that we don't have to do this on a global / static trace + # viz, because the 'zoomers' will be notified as a listener of + # the selection change events + # + + if self._is_zoom: + + # + # ensure the committed selection is also reset as we are about + # to zoom-in and should not have an active selection once done + # + + self._idx_selection_start = INVALID_IDX + self._idx_selection_end = INVALID_IDX + + # + # apply the new zoom-in / viz bounds to ourself + # + # NOTE: because the special cell-drawing metrics / computation, set + # bounds can 'tweak' the end value, so we want to grab it here + # + + new_start, new_end = self.set_bounds(new_start, new_end) + + # commit the new selection for global trace visualizations + else: + self._idx_selection_start = new_start + self._idx_selection_end = new_end + + # notify listeners of our selection change + self._notify_selection_changed(new_start, new_end) + + def _get_closest_highlighted_idx(self, idx): + """ + Return the closest idx (timestamp) to the given idx. + """ + closest_idx = INVALID_IDX + smallest_distace = 999999999999999999999999 + for entries in [self._idx_reads, self._idx_writes, self._idx_executions]: + for current_idx in entries: + distance = abs(idx - current_idx) + if distance < smallest_distace: + closest_idx = current_idx + smallest_distace = distance + return closest_idx + + def _breakpoints_changed(self): + """ + The focused breakpoint has changed. + """ + self._refresh_trace_highlights() + self.refresh() + + def _refresh_trace_highlights(self): + """ + Refresh trace event / highlight info from the underlying trace reader. + """ + self._idx_reads = [] + self._idx_writes = [] + self._idx_executions = [] + + reader, density = self.reader, self.density + if not (reader and density != INVALID_DENSITY): + return + + model = self.pctx.breakpoints.model + + # fetch executions for all breakpoints + for bp in model.bp_exec.values(): + executions = reader.get_executions_between(bp.address, self.start_idx, self.end_idx, density) + self._idx_executions.extend(executions) + + # fetch all memory read (only) breakpoints hits + for bp in model.bp_read.values(): + if bp.length == 1: + reads = reader.get_memory_reads_between(bp.address, self.start_idx, self.end_idx, density) + else: + reads = reader.get_memory_region_reads_between(bp.address, bp.length, self.start_idx, self.end_idx, density) + self._idx_reads.extend(reads) + + # fetch all memory write (only) breakpoint hits + for bp in model.bp_write.values(): + if bp.length == 1: + writes = reader.get_memory_writes_between(bp.address, self.start_idx, self.end_idx, density) + else: + writes = reader.get_memory_region_writes_between(bp.address, bp.length, self.start_idx, self.end_idx, density) + self._idx_writes.extend(writes) + + # fetch memory access for all breakpoints + for bp in model.bp_access.values(): + if bp.length == 1: + reads, writes = reader.get_memory_accesses_between(bp.address, self.start_idx, self.end_idx, density) + else: + reads, writes = reader.get_memory_region_accesses_between(bp.address, bp.length, self.start_idx, self.end_idx, density) + self._idx_reads.extend(reads) + self._idx_writes.extend(writes) + + def _clamp_idx(self, idx): + """ + Clamp the given idx to the bounds of this trace view. + """ + if idx < self.start_idx: + return self.start_idx + elif idx >= self.end_idx: + return self.end_idx - 1 + return idx + + #------------------------------------------------------------------------- + # Drawing + #------------------------------------------------------------------------- + + def paintEvent(self, event): + """ + Qt overload of widget painting. + + TODO/FUTURE: I was planning to make this paint by layer, and only + re-paint dirty layers as necessary. but I think it's unecessary to + do at this time as I don't think we're pressed for perf. + """ + painter = QtGui.QPainter(self) + + # + # draw instructions / trace landscape + # + + self._draw_base() + painter.drawImage(0, 0, self._image_base) + + # + # draw accesses along the trace timeline + # + + self._draw_highlights() + painter.drawImage(0, 0, self._image_highlights) + + # + # draw user region selection over trace timeline + # + + self._draw_selection() + painter.drawImage(0, 0, self._image_selection) + + # + # draw border around trace timeline + # + + self._draw_border() + painter.drawImage(0, 0, self._image_border) + + # + # draw current trace position cursor + # + + self._draw_cursor() + painter.drawImage(0, 0, self._image_cursor) + + #painter.drawImage(0, 0, self._image_final) + + def _draw_base(self): + """ + Draw the trace visualization of executed code. + """ + + # + # NOTE: DO NOT REMOVE !!! Qt will CRASH if we do not explicitly delete + # these here (dangling internal pointer to device/image otherwise?!?) + # + + del self._painter_base + + self._image_base = QtGui.QImage(self.width(), self.height(), QtGui.QImage.Format_ARGB32) + self._image_base.fill(self.pctx.palette.trace_bedrock) + #self._image_base.fill(QtGui.QColor("red")) # NOTE/debug + self._painter_base = QtGui.QPainter(self._image_base) + + # redraw instructions + if self.cells_visible: + self._draw_code_cells(self._painter_base) + else: + self._draw_code_trace(self._painter_base) + + def _draw_code_trace(self, painter): + """ + Draw a 'zoomed out' trace visualization of executed code. + """ + dctx = disassembler[self.pctx] + viz_w, viz_h = self.viz_size + viz_x, viz_y = self.viz_pos + + for i in range(viz_h): + + # convert a y pixel in the viz region to an executed address + wid_y = viz_y + i + idx = self._pos2idx(wid_y) + + # + # since we can conciously set a trace visualization bounds bigger + # than the actual underlying trace, it is possible for the trace + # to not take up the entire available space. + # + # when we reach the 'end' of the trace, we obviously can stop + # drawing any sort of landscape for it! + # + + if idx >= self._last_trace_idx: + break + + # get the executed/code address for the current idx that will represent this line + address = self.reader.get_ip(idx) + rebased_address = self.reader.analysis.rebase_pointer(address) + + # select the color for instructions that can be viewed with Tenet + if dctx.is_mapped(rebased_address): + painter.setPen(self.pctx.palette.trace_instruction) + + # unexplorable parts of the trace are 'greyed' out (eg, not in IDB) + else: + painter.setPen(self.pctx.palette.trace_unmapped) + + # paint the current line + painter.drawLine(viz_x, wid_y, viz_w, wid_y) + + def _draw_code_cells(self, painter): + """ + Draw a 'zoomed in', cell-based, trace visualization of executed code. + """ + + # + # if there is no spacing between cells, that means they are going to + # be relatively small and have shared 'cell walls' (borders) + # + # we attempt to maximize contrast between border and cell color, while + # attempting to keep the tracebar color visually consistent + # + + # compute the color to use for the borders between cells + border_color = self.pctx.palette.trace_cell_wall + if self._cell_spacing < 0: + border_color = self.pctx.palette.trace_cell_wall_contrast + + # compute the color to use for the cell bodies + if self._cell_spacing < 0: + ratio = (self._cell_border / (self._cell_height - 1)) * 0.5 + lighten = 100 + int(ratio * 100) + cell_color = self.pctx.palette.trace_instruction.lighter(lighten) + #print(f"Lightened by {lighten}% (Border: {self._cell_border}, Body: {self._cell_height}") + else: + cell_color = self.pctx.palette.trace_instruction + + border_pen = QtGui.QPen(border_color, self._cell_border, QtCore.Qt.SolidLine) + painter.setPen(border_pen) + painter.setBrush(cell_color) + + viz_x, _ = self.viz_pos + viz_w, _ = self.viz_size + + # compute cell positioning info + x = viz_x + self._cell_border * -1 + w = viz_w + self._cell_border + h = self._cell_height + + dctx = disassembler[self.pctx] + + # draw each cell + border + for idx in range(self.start_idx, self._last_trace_idx): + + # get the executed/code address for the current idx that will represent this cell + address = self.reader.get_ip(idx) + rebased_address = self.reader.analysis.rebase_pointer(address) + + # select the color for instructions that can be viewed with Tenet + if dctx.is_mapped(rebased_address): + painter.setBrush(cell_color) + + # unexplorable parts of the trace are 'greyed' out (eg, not in IDB) + else: + painter.setBrush(self.pctx.palette.trace_unmapped) + + y = self._idx2pos(idx) + painter.drawRect(int(x), int(y), int(w), int(h)) + + def _draw_highlights(self): + """ + Draw active event highlights (mem access, breakpoints) for the trace visualization. + """ + + # + # NOTE: DO NOT REMOVE !!! Qt will CRASH if we do not explicitly delete + # these here (dangling internal pointer to device/image otherwise?!?) + # + + del self._painter_highlights + + self._image_highlights = QtGui.QImage(self.width(), self.height(), QtGui.QImage.Format_ARGB32) + self._image_highlights.fill(QtCore.Qt.transparent) + self._painter_highlights = QtGui.QPainter(self._image_highlights) + + if self.cells_visible: + self._draw_highlights_cells(self._painter_highlights) + else: + self._draw_highlights_trace(self._painter_highlights) + + def _draw_highlights_cells(self, painter): + """ + Draw cell-based event highlights. + """ + viz_w, _ = self.viz_size + viz_x, _ = self.viz_pos + + access_sets = \ + [ + (self._idx_reads, self.pctx.palette.mem_read_bg), + (self._idx_writes, self.pctx.palette.mem_write_bg), + (self._idx_executions, self.pctx.palette.breakpoint), + ] + + painter.setPen(QtCore.Qt.NoPen) + + h = self._cell_height - self._cell_border + + for entries, cell_color in access_sets: + painter.setBrush(cell_color) + + for idx in entries: + + # skip entries that fall outside the visible zoom + if not(self.start_idx <= idx < self.end_idx): + continue + + # slight tweak of y because we are only drawing a highlighted + # cell body without borders + y = self._idx2pos(idx) + self._cell_border + + # draw cell body + painter.drawRect(int(viz_x), int(y), int(viz_w), int(h)) + + def _draw_highlights_trace(self, painter): + """ + Draw trace-based event highlights. + """ + viz_w, _ = self.viz_size + viz_x, _ = self.viz_pos + + access_sets = \ + [ + (self._idx_reads, self.pctx.palette.mem_read_bg), + (self._idx_writes, self.pctx.palette.mem_write_bg), + (self._idx_executions, self.pctx.palette.breakpoint), + ] + + for entries, color in access_sets: + painter.setPen(color) + + for idx in entries: + + # skip entries that fall outside the visible zoom + if not(self.start_idx <= idx < self.end_idx): + continue + + y = self._idx2pos(idx) + painter.drawLine(viz_x, y, viz_w, y) + + def _draw_cursor(self): + """ + Draw the user cursor / current position in the trace. + """ + path = QtGui.QPainterPath() + + size = 13 + assert size % 2, "Cursor triangle size must be odd" + + del self._painter_cursor + self._image_cursor = QtGui.QImage(self.width(), self.height(), QtGui.QImage.Format_ARGB32) + self._image_cursor.fill(QtCore.Qt.transparent) + self._painter_cursor = QtGui.QPainter(self._image_cursor) + + # compute the y coordinate / line to center the user cursor around + cursor_y = self._idx2pos(self.reader.idx) + draw_reader_cursor = bool(cursor_y != INVALID_IDX) + + if self.cells_visible: + cell_y = cursor_y + self._cell_border + cell_body_height = self._cell_height - self._cell_border + cursor_y += self._cell_height/2 + + # the top point of the triangle + top_x = 0 + top_y = cursor_y - (size // 2) # vertically align the triangle so the tip matches the cross section + + # bottom point of the triangle + bottom_x = top_x + bottom_y = top_y + size - 1 + + # the 'tip' of the triangle pointing into towards the center of the trace + tip_x = top_x + (size // 2) + tip_y = top_y + (size // 2) + + # start drawing from the 'top' of the triangle + path.moveTo(top_x, top_y) + + # generate the triangle path / shape + path.lineTo(bottom_x, bottom_y) + path.lineTo(tip_x, tip_y) + path.lineTo(top_x, top_y) + + viz_x, _ = self.viz_pos + viz_w, _ = self.viz_size + + # draw the user cursor in cell mode + if self.cells_visible: + + # normal fixed / current reader cursor + self._painter_cursor.setPen(QtCore.Qt.NoPen) + self._painter_cursor.setBrush(self.pctx.palette.trace_cursor_highlight) + + if draw_reader_cursor: + self._painter_cursor.drawRect(int(viz_x), int(cell_y), int(viz_w), int(cell_body_height)) + + # cursor hover highlighting an event + if self._hovered_idx != INVALID_IDX: + hovered_y = self._idx2pos(self._hovered_idx) + hovered_cell_y = hovered_y + self._cell_border + self._painter_cursor.drawRect(int(viz_x), int(hovered_cell_y), int(viz_w), int(cell_body_height)) + + # draw the user cursor in dense/landscape mode + else: + self._painter_cursor.setPen(self._pen_cursor) + + # normal fixed / current reader cursor + if draw_reader_cursor: + self._painter_cursor.drawLine(viz_x, cursor_y, viz_w, cursor_y) + + # cursor hover highlighting an event + if self._hovered_idx != INVALID_IDX: + hovered_y = self._idx2pos(self._hovered_idx) + self._painter_cursor.drawLine(viz_x, hovered_y, viz_w, hovered_y) + + if not draw_reader_cursor: + return + + # paint the defined triangle + self._painter_cursor.setPen(self.pctx.palette.trace_cursor_border) + self._painter_cursor.setBrush(self.pctx.palette.trace_cursor) + self._painter_cursor.drawPath(path) + + def _draw_selection(self): + """ + Draw a region selection rect. + """ + + # + # NOTE: DO NOT REMOVE !!! Qt will CRASH if we do not explicitly delete + # these here (dangling internal pointer to device/image otherwise?!?) + # + + del self._painter_selection + + viz_w, viz_h = self.viz_size + self._image_selection = QtGui.QImage(self.width(), self.height(), QtGui.QImage.Format_ARGB32) + self._image_selection.fill(QtCore.Qt.transparent) + self._painter_selection = QtGui.QPainter(self._image_selection) + + # active / on-going selection event + if self._idx_pending_selection_start != INVALID_IDX: + start_idx = self._idx_pending_selection_start + end_idx = self._idx_pending_selection_end + + # fixed / committed selection + elif self._idx_selection_start != INVALID_IDX: + start_idx = self._idx_selection_start + end_idx = self._idx_selection_end + + # no region selection, nothing to do... + else: + return + + start_idx = self._clamp_idx(start_idx) + end_idx = self._clamp_idx(end_idx) + + # nothing to draw + if start_idx == end_idx: + return + + start_y = self._idx2pos(start_idx) + end_y = self._idx2pos(end_idx) + + self._painter_selection.setBrush(self._brush_selection) + self._painter_selection.setPen(self._pen_selection) + + # TODO/FUTURE: real border math + viz_x, viz_y = self.viz_pos + + x = viz_x + y = start_y + w = viz_w + h = end_y - start_y + + # draw the screen door / selection rect + self._painter_selection.drawRect(int(x), int(y), int(w), int(h)) + + def _draw_border(self): + """ + Draw the border around the trace timeline. + """ + wid_w, wid_h = self.width(), self.height() + + # + # NOTE: DO NOT REMOVE !!! Qt will CRASH if we do not explicitly delete + # these here (dangling internal pointer to device/image otherwise?!?) + # + + del self._painter_border + + self._image_border = QtGui.QImage(wid_w, wid_h, QtGui.QImage.Format_ARGB32) + self._image_border.fill(QtCore.Qt.transparent) + self._painter_border = QtGui.QPainter(self._image_border) + + color = self.pctx.palette.trace_border + #color = QtGui.QColor("red") # NOTE: for dev/debug testing + border_pen = QtGui.QPen(color, self._trace_border, QtCore.Qt.SolidLine) + self._painter_border.setPen(border_pen) + + w = wid_w - self._trace_border + h = wid_h - self._trace_border + + # draw the border around the tracebar using a blank rect + stroke (border) + self._painter_border.drawRect(0, 0, w, h) + + #---------------------------------------------------------------------- + # Callbacks + #---------------------------------------------------------------------- + + def selection_changed(self, callback): + """ + Subscribe a callback for a trace slice selection change event. + """ + register_callback(self._selection_changed_callbacks, callback) + + def _notify_selection_changed(self, start_idx, end_idx): + """ + Notify listeners of a trace slice selection change event. + """ + notify_callback(self._selection_changed_callbacks, start_idx, end_idx) + +#----------------------------------------------------------------------------- +# Trace View +#----------------------------------------------------------------------------- + +class TraceView(QtWidgets.QWidget): + + def __init__(self, pctx, parent=None): + super(TraceView, self).__init__(parent) + self.pctx = pctx + self._init_ui() + + def _init_ui(self): + self._init_bars() + self._init_ctx_menu() + + def attach_reader(self, reader): + self.trace_global.attach_reader(reader) + self.trace_local.attach_reader(reader) + self.trace_local.hide() + + def detach_reader(self): + self.trace_global.reset() + self.trace_local.reset() + self.trace_local.hide() + + def _init_bars(self): + self.trace_local = TraceBar(self.pctx, zoom=True) + self.trace_global = TraceBar(self.pctx) + + # connect the local view to follow the global selection + self.trace_global.selection_changed(self.trace_local._zoom_selection_changed) + self.trace_local.selection_changed(self.trace_global._global_selection_changed) + + # connect other signals + self.pctx.breakpoints.model.breakpoints_changed(self.trace_global._breakpoints_changed) + self.pctx.breakpoints.model.breakpoints_changed(self.trace_local._breakpoints_changed) + + # hide the zoom bar by default + self.trace_local.hide() + + # setup the layout and spacing for the tracebar + hbox = QtWidgets.QHBoxLayout(self) + hbox.setContentsMargins(3, 3, 3, 3) + hbox.setSpacing(3) + + # add the layout container / mechanism to the toolbar + hbox.addWidget(self.trace_local) + hbox.addWidget(self.trace_global) + + self.setLayout(hbox) + + def _init_ctx_menu(self): + """ + Initialize the right click context menu actions. + """ + self._menu = QtWidgets.QMenu() + + # create actions to show in the context menu + self._action_clear = self._menu.addAction("Clear all breakpoints") + self._menu.addSeparator() + self._action_load = self._menu.addAction("Load new trace") + self._action_close = self._menu.addAction("Close trace") + + # install the right click context menu + self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) + self.customContextMenuRequested.connect(self._ctx_menu_handler) + + #-------------------------------------------------------------------------- + # Signals + #-------------------------------------------------------------------------- + + def _ctx_menu_handler(self, position): + """ + Handle a right click event (populate/show context menu). + """ + action = self._menu.exec_(self.mapToGlobal(position)) + if action == self._action_load: + self.pctx.interactive_load_trace(True) + elif action == self._action_close: + self.pctx.close_trace() + elif action == self._action_clear: + self.pctx.breakpoints.clear_breakpoints() + + def update_from_model(self): + for bar in self.model.tracebars.values()[::-1]: + self.hbox.addWidget(bar) + + # this will insert the children (tracebars) and apply spacing as appropriate + self.bar_container.setLayout(self.hbox) + +#----------------------------------------------------------------------------- +# Dockable Trace Visualization +#----------------------------------------------------------------------------- + +class TraceDock(QtWidgets.QToolBar): + """ + A Qt 'Toolbar' to house the TraceBar visualizations. + + We use a Toolbar explicitly because they are given unique docking regions + around the QMainWindow in Qt-based applications. This allows us to pin + the visualizations to areas where they will not be dist + """ + def __init__(self, pctx, parent=None): + super(TraceDock, self).__init__(parent) + self.pctx = pctx + self.view = TraceView(pctx, self) + self.setMovable(False) + self.setContentsMargins(0, 0, 0, 0) + self.addWidget(self.view) + + def attach_reader(self, reader): + self.view.attach_reader(reader) + + def detach_reader(self): + self.view.detach_reader() \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/__init__.py b/plugins_sogen-support/tenet/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins_sogen-support/tenet/util/__pycache__/__init__.cpython-311.pyc b/plugins_sogen-support/tenet/util/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..5cec821 Binary files /dev/null and b/plugins_sogen-support/tenet/util/__pycache__/__init__.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/__pycache__/log.cpython-311.pyc b/plugins_sogen-support/tenet/util/__pycache__/log.cpython-311.pyc new file mode 100644 index 0000000..73ed335 Binary files /dev/null and b/plugins_sogen-support/tenet/util/__pycache__/log.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/__pycache__/misc.cpython-311.pyc b/plugins_sogen-support/tenet/util/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000..bd5568d Binary files /dev/null and b/plugins_sogen-support/tenet/util/__pycache__/misc.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/__pycache__/update.cpython-311.pyc b/plugins_sogen-support/tenet/util/__pycache__/update.cpython-311.pyc new file mode 100644 index 0000000..504f782 Binary files /dev/null and b/plugins_sogen-support/tenet/util/__pycache__/update.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/debug.py b/plugins_sogen-support/tenet/util/debug.py new file mode 100644 index 0000000..22d5aa2 --- /dev/null +++ b/plugins_sogen-support/tenet/util/debug.py @@ -0,0 +1,19 @@ +import time + +#------------------------------------------------------------------------------ +# Debug / Profiling Helpers +#------------------------------------------------------------------------------ + +def timeit(method): + def timed(*args, **kw): + ts = time.time() + result = method(*args, **kw) + te = time.time() + if 'log_time' in kw: + name = kw.get('log_name', method.__name__.upper()) + kw['log_time'][name] = int((te - ts) * 1000) + else: + print('%r %2.2f ms' % \ + (method.__name__, (te - ts) * 1000)) + return result + return timed \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/log.py b/plugins_sogen-support/tenet/util/log.py new file mode 100644 index 0000000..4c4a3ed --- /dev/null +++ b/plugins_sogen-support/tenet/util/log.py @@ -0,0 +1,153 @@ +import os +import sys +import logging + +from .misc import makedirs, is_plugin_dev +from ..integration.api import disassembler + +#------------------------------------------------------------------------------ +# Log / Print helpers +#------------------------------------------------------------------------------ + +def pmsg(message): + """ + Print a 'plugin message' to the disassembler output window. + """ + + # prefix the message + prefix_message = "[Tenet] %s" % message + + # only print to disassembler if its output window is alive + if disassembler.is_msg_inited(): + disassembler.message(prefix_message) + else: + logger.info(message) + +def get_log_dir(): + """ + Return the plugin log directory. + """ + log_directory = os.path.join( + disassembler.get_disassembler_user_directory(), + "tenet_logs" + ) + return log_directory + +def logging_started(): + """ + Check if logging has been started. + """ + return 'logger' in globals() + +#------------------------------------------------------------------------------ +# Logger Proxy +#------------------------------------------------------------------------------ + +class LoggerProxy(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + def __init__(self, logger, stream, log_level=logging.INFO): + self._logger = logger + self._log_level = log_level + self._stream = stream + + def write(self, buf): + for line in buf.rstrip().splitlines(): + self._logger.log(self._log_level, line.rstrip()) + if self._stream: + self._stream.write(buf) + + def flush(self): + pass + + def isatty(self): + pass + +#------------------------------------------------------------------------------ +# Initialize Logging +#------------------------------------------------------------------------------ + +MAX_LOGS = 10 +def cleanup_log_directory(log_directory): + """ + Retain only the last 15 logs. + """ + filetimes = {} + + # build a map of all the files in the directory, and their last modified time + for log_name in os.listdir(log_directory): + filepath = os.path.join(log_directory, log_name) + if os.path.isfile(filepath): + filetimes[os.path.getmtime(filepath)] = filepath + + # get the filetimes and check if there's enough to warrant cleanup + times = list(filetimes.keys()) + if len(times) < MAX_LOGS: + return + + logger.debug("Cleaning logs directory") + + # discard the newest 15 logs + times.sort(reverse=True) + times = times[MAX_LOGS:] + + # loop through the remaining older logs, and delete them + for log_time in times: + try: + os.remove(filetimes[log_time]) + except Exception as e: + logger.error("Failed to delete log %s" % filetimes[log_time]) + logger.error(e) + +def start_logging(): + global logger + + # create the plugin logger + logger = logging.getLogger("Tenet") + + # + # only enable logging if the plugin-specific environment variable is + # present. otherwive we return a stub logger to sinkhole messages. + # + + if not is_plugin_dev(): + logger.disabled = True + return logger + + # create a directory for plugin logs if it does not exist + log_dir = get_log_dir() + try: + makedirs(log_dir) + except Exception as e: + logger.disabled = True + return logger + + # construct the full log path + log_path = os.path.join(log_dir, "tenet.%s.log" % os.getpid()) + + # config the logger + logging.basicConfig( + filename=log_path, + format='%(asctime)s | %(name)28s | %(levelname)7s: %(message)s', + datefmt='%m-%d-%Y %H:%M:%S', + level=logging.DEBUG + ) + + # proxy STDOUT/STDERR to the log files too + stdout_logger = logging.getLogger('Tenet.STDOUT') + stderr_logger = logging.getLogger('Tenet.STDERR') + sys.stdout = LoggerProxy(stdout_logger, sys.stdout, logging.INFO) + sys.stderr = LoggerProxy(stderr_logger, sys.stderr, logging.ERROR) + + # limit the number of logs we keep + cleanup_log_directory(log_dir) + + return logger + +#------------------------------------------------------------------------------ +# Log Helpers +#------------------------------------------------------------------------------ + +def log_config_warning(self, logger, section, field): + logger.warning("Config missing field '%s' in section '%s", field, section) diff --git a/plugins_sogen-support/tenet/util/misc.py b/plugins_sogen-support/tenet/util/misc.py new file mode 100644 index 0000000..efe59ff --- /dev/null +++ b/plugins_sogen-support/tenet/util/misc.py @@ -0,0 +1,174 @@ +import os +import errno +import struct +import weakref +import threading + +#------------------------------------------------------------------------------ +# Plugin Util +#------------------------------------------------------------------------------ + +PLUGIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + +def is_plugin_dev(): + """ + Return True if the plugin is in developer mode. + """ + return bool(os.getenv("TENET_DEV")) + +def plugin_resource(resource_name): + """ + Return the full path for a given plugin resource file. + """ + return os.path.join( + PLUGIN_PATH, + "ui", + "resources", + resource_name + ) + +#------------------------------------------------------------------------------ +# Thread Util +#------------------------------------------------------------------------------ + +def is_mainthread(): + """ + Return a bool that indicates if this is the main application thread. + """ + return isinstance(threading.current_thread(), threading._MainThread) + +def assert_mainthread(f): + """ + A sanity decorator to ensure that a function is always called from the main thread. + """ + def wrapper(*args, **kwargs): + assert is_mainthread() + return f(*args, **kwargs) + return wrapper + +def assert_async(f): + """ + A sanity decorator to ensure that a function is never called from the main thread. + """ + def wrapper(*args, **kwargs): + assert not is_mainthread() + return f(*args, **kwargs) + return wrapper + +#----------------------------------------------------------------------------- +# Python Utils +#----------------------------------------------------------------------------- + +def chunks(lst, n): + """ + Yield successive n-sized chunks from lst. + """ + for i in range(0, len(lst), n): + yield lst[i:i + n] + +def hexdump(data): + """ + Return an ascii hexdump of the given data. + """ + return '\n'.join([' '.join([f"{x:02X}" for x in chunk]) for chunk in chunks(data, 16)]) + +def makedirs(path, exists_ok=True): + """ + Create directories along a fully qualified path. + """ + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise e + if not exists_ok: + raise e + +def swap_rgb(i): + """ + Swap a 32bit RRGGBB (integer) to BBGGRR. + """ + return struct.unpack("I", i))[0] >> 8 + +#------------------------------------------------------------------------------ +# Python Callback / Signals +#------------------------------------------------------------------------------ + +def register_callback(callback_list, callback): + """ + Register a callable function to the given callback_list. + + Adapted from http://stackoverflow.com/a/21941670 + """ + + # create a weakref callback to an object method + try: + callback_ref = weakref.ref(callback.__func__), weakref.ref(callback.__self__) + + # create a wweakref callback to a stand alone function + except AttributeError: + callback_ref = weakref.ref(callback), None + + # 'register' the callback + callback_list.append(callback_ref) + +def notify_callback(callback_list, *args): + """ + Notify the given list of registered callbacks of an event. + + The given list (callback_list) is a list of weakref'd callables + registered through the register_callback() function. To notify the + callbacks of an event, this function will simply loop through the list + and call them. + + This routine self-heals by removing dead callbacks for deleted objects as + it encounters them. + + Adapted from http://stackoverflow.com/a/21941670 + """ + cleanup = [] + + # + # loop through all the registered callbacks in the given callback_list, + # notifying active callbacks, and removing dead ones. + # + + for callback_ref in callback_list: + callback, obj_ref = callback_ref[0](), callback_ref[1] + + # + # if the callback is an instance method, deference the instance + # (an object) first to check that it is still alive + # + + if obj_ref: + obj = obj_ref() + + # if the object instance is gone, mark this callback for cleanup + if obj is None: + cleanup.append(callback_ref) + continue + + # call the object instance callback + try: + callback(obj, *args) + + # assume a Qt cleanup/deletion occurred + except RuntimeError as e: + cleanup.append(callback_ref) + continue + + # if the callback is a static method... + else: + + # if the static method is deleted, mark this callback for cleanup + if callback is None: + cleanup.append(callback_ref) + continue + + # call the static callback + callback(*args) + + # remove the deleted callbacks + for callback_ref in cleanup: + callback_list.remove(callback_ref) \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/qt/__init__.py b/plugins_sogen-support/tenet/util/qt/__init__.py new file mode 100644 index 0000000..810ea60 --- /dev/null +++ b/plugins_sogen-support/tenet/util/qt/__init__.py @@ -0,0 +1,5 @@ +from .shim import * + +if QT_AVAILABLE: + from .util import * + from .waitbox import WaitBox \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/qt/__pycache__/__init__.cpython-311.pyc b/plugins_sogen-support/tenet/util/qt/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..8ee3f11 Binary files /dev/null and b/plugins_sogen-support/tenet/util/qt/__pycache__/__init__.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/qt/__pycache__/shim.cpython-311.pyc b/plugins_sogen-support/tenet/util/qt/__pycache__/shim.cpython-311.pyc new file mode 100644 index 0000000..657c6c1 Binary files /dev/null and b/plugins_sogen-support/tenet/util/qt/__pycache__/shim.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/qt/__pycache__/util.cpython-311.pyc b/plugins_sogen-support/tenet/util/qt/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000..bb9c2c9 Binary files /dev/null and b/plugins_sogen-support/tenet/util/qt/__pycache__/util.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/qt/__pycache__/waitbox.cpython-311.pyc b/plugins_sogen-support/tenet/util/qt/__pycache__/waitbox.cpython-311.pyc new file mode 100644 index 0000000..8bcb454 Binary files /dev/null and b/plugins_sogen-support/tenet/util/qt/__pycache__/waitbox.cpython-311.pyc differ diff --git a/plugins_sogen-support/tenet/util/qt/shim.py b/plugins_sogen-support/tenet/util/qt/shim.py new file mode 100644 index 0000000..dd4793c --- /dev/null +++ b/plugins_sogen-support/tenet/util/qt/shim.py @@ -0,0 +1,66 @@ + +# +# this global is used to indicate whether Qt bindings for python are present +# and available for use by Lighthouse. +# + +QT_AVAILABLE = False + +#------------------------------------------------------------------------------ +# PyQt5 <--> PySide2 Compatibility +#------------------------------------------------------------------------------ +# +# we use this file to shim/re-alias a few Qt API's to ensure compatibility +# between the popular Qt frameworks. these shims serve to reduce the number +# of compatibility checks in the plugin code that consumes them. +# +# this file was critical for retaining compatibility with Qt4 frameworks +# used by IDA 6.8/6.95, but it less important now. support for Qt 4 and +# older versions of IDA will be deprecated in Lighthouse v0.9.0 +# + +USING_PYQT5 = False +USING_PYSIDE2 = False + +#------------------------------------------------------------------------------ +# PyQt5 Compatibility +#------------------------------------------------------------------------------ + +# attempt to load PyQt5 +if QT_AVAILABLE == False: + try: + import PyQt5.QtGui as QtGui + import PyQt5.QtCore as QtCore + import PyQt5.QtWidgets as QtWidgets + from PyQt5 import sip + + # importing went okay, PyQt5 must be available for use + QT_AVAILABLE = True + USING_PYQT5 = True + + # import failed, PyQt5 is not available + except ImportError: + pass + +#------------------------------------------------------------------------------ +# PySide2 Compatibility +#------------------------------------------------------------------------------ + +# if PyQt5 did not import, try to load PySide +if QT_AVAILABLE == False: + try: + import PySide2.QtGui as QtGui + import PySide2.QtCore as QtCore + import PySide2.QtWidgets as QtWidgets + + # alias for less PySide2 <--> PyQt5 shimming + QtCore.pyqtSignal = QtCore.Signal + QtCore.pyqtSlot = QtCore.Slot + + # importing went okay, PySide must be available for use + QT_AVAILABLE = True + USING_PYSIDE2 = True + + # import failed. No Qt / UI bindings available... + except ImportError: + pass \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/qt/util.py b/plugins_sogen-support/tenet/util/qt/util.py new file mode 100644 index 0000000..758f41a --- /dev/null +++ b/plugins_sogen-support/tenet/util/qt/util.py @@ -0,0 +1,86 @@ +import sys +import time + +from .shim import * + +#------------------------------------------------------------------------------ +# Qt Fonts +#------------------------------------------------------------------------------ + +def MonospaceFont(): + """ + Convenience alias for creating a monospace Qt font object. + """ + font = QtGui.QFont("Courier New") + font.setStyleHint(QtGui.QFont.TypeWriter) + return font + +#------------------------------------------------------------------------------ +# Qt Util +#------------------------------------------------------------------------------ + +def copy_to_clipboard(data): + """ + Copy the given data (a string) to the system clipboard. + """ + cb = QtWidgets.QApplication.clipboard() + cb.clear(mode=cb.Clipboard) + cb.setText(data, mode=cb.Clipboard) + +def flush_qt_events(): + """ + Flush the Qt event pipeline. + """ + app = QtCore.QCoreApplication.instance() + app.processEvents() + +def focus_window(): + """ + Lame helper function to help with dev/debug. + """ + mb = QtWidgets.QMessageBox(get_qmainwindow()) + mb.setText("Click to take focus...") + mb.setStandardButtons(QtWidgets.QMessageBox.Ok) + button = mb.button(QtWidgets.QMessageBox.Ok) + mb.exec_() + +def get_dpi_scale(): + """ + Get a DPI-afflicted value useful for consistent UI scaling. + """ + font = MonospaceFont() + font.setPointSize(normalize_font(120)) + fm = QtGui.QFontMetricsF(font) + + # xHeight is expected to be 40.0 at normal DPI + return fm.height() / 173.0 + +def normalize_font(font_size): + """ + Normalize the given font size based on the system DPI. + """ + if sys.platform == "darwin": # macos is lame + return font_size + 2 + return font_size + +def get_qmainwindow(): + """ + Get the QMainWindow instance for the current Qt runtime. + """ + app = QtWidgets.QApplication.instance() + return [x for x in app.allWidgets() if x.__class__ is QtWidgets.QMainWindow][0] + +def compute_color_on_gradient(percent, color1, color2): + """ + Compute the color specified by a percent between two colors. + """ + r1, g1, b1, _ = color1.getRgb() + r2, g2, b2, _ = color2.getRgb() + + # compute the new color across the gradient of color1 -> color 2 + r = r1 + percent * (r2 - r1) + g = g1 + percent * (g2 - g1) + b = b1 + percent * (b2 - b1) + + # return the new color + return QtGui.QColor(r,g,b) \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/qt/waitbox.py b/plugins_sogen-support/tenet/util/qt/waitbox.py new file mode 100644 index 0000000..586dcc8 --- /dev/null +++ b/plugins_sogen-support/tenet/util/qt/waitbox.py @@ -0,0 +1,102 @@ +from .shim import * +from .util import get_dpi_scale + +import logging +logger = logging.getLogger("Tenet.Qt.WaitBox") + +#-------------------------------------------------------------------------- +# Qt WaitBox +#-------------------------------------------------------------------------- + +class WaitBox(QtWidgets.QDialog): + """ + A Generic Qt WaitBox Dialog. + """ + + def __init__(self, text, title="Please wait...", abort=None): + super(WaitBox, self).__init__() + + # dialog text & window title + self._text = text + self._title = title + + # abort routine (optional) + self._abort = abort + + # initialize the dialog UI + self._ui_init() + + def set_text(self, text): + """ + Change the waitbox text. + """ + self._text = text + self._text_label.setText(text) + qta = QtCore.QCoreApplication.instance() + qta.processEvents() + + def show(self, modal=True): + self.setModal(modal) + result = super(WaitBox, self).show() + qta = QtCore.QCoreApplication.instance() + qta.processEvents() + + #-------------------------------------------------------------------------- + # Initialization - UI + #-------------------------------------------------------------------------- + + def _ui_init(self): + """ + Initialize UI elements. + """ + self.setWindowFlags( + self.windowFlags() & ~QtCore.Qt.WindowContextHelpButtonHint + ) + self.setWindowFlags( + self.windowFlags() | QtCore.Qt.MSWindowsFixedSizeDialogHint + ) + self.setWindowFlags( + self.windowFlags() & ~QtCore.Qt.WindowCloseButtonHint + ) + + # configure the main widget / form + self.setSizeGripEnabled(False) + self.setModal(True) + self._dpi_scale = get_dpi_scale()*5.0 + + # initialize abort button + self._abort_button = QtWidgets.QPushButton("Cancel") + + # layout the populated UI just before showing it + self._ui_layout() + + def _ui_layout(self): + """ + Layout the major UI elements of the widget. + """ + self.setWindowTitle(self._title) + self._text_label = QtWidgets.QLabel(self._text) + self._text_label.setAlignment(QtCore.Qt.AlignHCenter) + + # vertical layout (whole widget) + v_layout = QtWidgets.QVBoxLayout() + v_layout.setAlignment(QtCore.Qt.AlignCenter) + v_layout.addWidget(self._text_label) + if self._abort: + self._abort_button.clicked.connect(self._abort) + v_layout.addWidget(self._abort_button) + + v_layout.setSpacing(int(self._dpi_scale*3)) + v_layout.setContentsMargins( + int(self._dpi_scale*5), + int(self._dpi_scale), + int(self._dpi_scale*5), + int(self._dpi_scale) + ) + + # scale widget dimensions based on DPI + height = int(self._dpi_scale * 15) + self.setMinimumHeight(height) + + # compute the dialog layout + self.setLayout(v_layout) \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/rebase.py b/plugins_sogen-support/tenet/util/rebase.py new file mode 100644 index 0000000..802288a --- /dev/null +++ b/plugins_sogen-support/tenet/util/rebase.py @@ -0,0 +1,99 @@ +import ida_segment +import ida_nalt +import ida_auto +import ida_kernwin + +from tenet.util.log import pmsg + +def rebase_database_manually(delta, new_base_address): + """ + Manually rebase the database by moving each segment individually. + This is a robust fallback for when rebase_program fails. + """ + + + # 1. Collect segment information + segments_info = [] + for i in range(ida_segment.get_segm_qty()): + seg = ida_segment.getnseg(i) + if seg: + segments_info.append({ + 'seg': seg, + 'name': ida_segment.get_segm_name(seg), + 'start_ea': seg.start_ea, + 'new_start': seg.start_ea + delta + }) + + # 2. Sort segments to prevent overlaps during the move. + segments_info.sort(key=lambda s: s['start_ea'], reverse=delta > 0) + + # 3. Move each segment + for seg_info in segments_info: + seg = seg_info['seg'] + new_start = seg_info['new_start'] + seg_name = seg_info['name'] + + # Try to move with flags=2 first (preserves info) + if not ida_segment.move_segm(seg, new_start, 2): + # Check if it moved despite the error code (IDA API quirk) + updated_seg = ida_segment.getseg(new_start) + if not (updated_seg and ida_segment.get_segm_name(updated_seg) == seg_name): + ida_kernwin.warning(f"Manual rebase failed: could not move segment '{seg_name}'.") + return False + + #pmsg("All segments moved successfully in manual rebase.") + return True + +def rebase_database(new_base_address): + """ + Rebase the program using a two-phase approach: + 1. Attempt a fast, simple rebase with rebase_program. + 2. If that fails, use a more robust manual segment-moving algorithm. + """ + current_base = ida_nalt.get_imagebase() + + if current_base == new_base_address: + pmsg("Database is already based at the target address.") + return True + + delta = new_base_address - current_base + + if not ida_kernwin.ask_yn( + ida_kernwin.ASKBTN_YES, + f"A new base address (0x{new_base_address:X}) was found in the trace log. " + f"Would you like to rebase the database from 0x{current_base:X}?\n\n" + "(This is a permanent operation)" + ): + pmsg("Rebase operation cancelled by user.") + return False + + # --- Phase 1: Fast Rebase --- + #pmsg("Phase 1: Attempting fast rebase with rebase_program...") + flags = 4 # MSF_FIXONCE: Fix up the program connections, etc. + if ida_segment.rebase_program(delta, flags) == 0: + pass + #pmsg("rebase_program returned error, but checking if it worked anyway...") + + # --- Verification --- + if ida_nalt.get_imagebase() == new_base_address: + #pmsg("Phase 1 successful. Rerunning analysis...") + ida_auto.auto_wait() + return True + + #pmsg("Phase 1 failed. The database imagebase was not changed.") + + # --- Phase 2: Manual Fallback Rebase --- + #pmsg("Phase 2: Falling back to manual segment-by-segment rebase...") + if not rebase_database_manually(delta, new_base_address): + ida_kernwin.warning("Manual rebase also failed. The database may be in an inconsistent state.") + return False + + ida_nalt.set_imagebase(new_base_address) + #pmsg("Rerunning analysis after manual rebase...") + ida_auto.auto_wait() + + if ida_nalt.get_imagebase() == new_base_address: + return True + + ida_kernwin.warning(f"Rebase failed. Current base 0x{ida_nalt.get_imagebase():X} does not match target 0x{new_base_address:X}.") + return False \ No newline at end of file diff --git a/plugins_sogen-support/tenet/util/update.py b/plugins_sogen-support/tenet/util/update.py new file mode 100644 index 0000000..195089c --- /dev/null +++ b/plugins_sogen-support/tenet/util/update.py @@ -0,0 +1,59 @@ +import re +import json +import logging +import threading + +from urllib.request import urlopen + +logger = logging.getLogger("Tenet.Util.Update") + +#------------------------------------------------------------------------------ +# Update Checking +#------------------------------------------------------------------------------ + +UPDATE_URL = "https://api.github.com/repos/gaasedelen/tenet/releases/latest" + +def check_for_update(current_version, callback): + """ + Perform a plugin update check. + """ + update_thread = threading.Thread( + target=async_update_check, + args=(current_version, callback,), + name="UpdateChecker" + ) + update_thread.start() + +def async_update_check(current_version, callback): + """ + An async worker thread to check for an plugin update. + """ + logger.debug("Checking for update...") + + try: + response = urlopen(UPDATE_URL, timeout=5.0) + html = response.read() + info = json.loads(html) + remote_version = info["tag_name"] + except Exception: + logger.debug(" - Failed to reach GitHub for update check...") + return + + # convert vesrion #'s to integer for easy compare... + version_remote = int(''.join(re.findall('\d+', remote_version))) + version_local = int(''.join(re.findall('\d+', current_version))) + + # no updates available... + logger.debug(" - Local: 'v%s' vs Remote: '%s'" % (current_version, remote_version)) + if version_local >= version_remote: + logger.debug(" - No update needed...") + return + + # notify the user if an update is available + update_message = "An update is available for Tenet!\n\n" \ + " - Latest Version: %s\n" % (remote_version) + \ + " - Current Version: v%s\n\n" % (current_version) + \ + "Please go download the update from GitHub." + + callback(update_message) + diff --git a/plugins_sogen-support/tenet_plugin.py b/plugins_sogen-support/tenet_plugin.py new file mode 100644 index 0000000..e1896ac --- /dev/null +++ b/plugins_sogen-support/tenet_plugin.py @@ -0,0 +1,23 @@ +from tenet.util.log import logging_started, start_logging +from tenet.integration.api import disassembler + +if not logging_started(): + logger = start_logging() + +#------------------------------------------------------------------------------ +# Disassembler Agnonstic Plugin Loader +#------------------------------------------------------------------------------ + +logger.debug("Resolving disassembler platform for Tenet...") + +if disassembler.headless: + logger.info("Disassembler '%s' is running headlessly" % disassembler.NAME) + logger.info(" - Tenet is not supported in headless modes (yet!)") + +elif disassembler.NAME == "IDA": + logger.info("Selecting IDA loader...") + from tenet.integration.ida_loader import * + +else: + raise NotImplementedError("DISASSEMBLER-SPECIFIC SHIM MISSING") + diff --git a/tracers/pin/pintenet.cpp b/tracers/pin/pintenet.cpp index 9291844..a85be39 100644 --- a/tracers/pin/pintenet.cpp +++ b/tracers/pin/pintenet.cpp @@ -1,432 +1,445 @@ -// -// pintenet.cpp, a Proof-of-Concept Tenet Tracer -// -// -- by Patrick Biernat & Markus Gaasedelen -// @ RET2 Systems, Inc. -// -// Adaptions from the CodeCoverage pin tool by Agustin Gianni as -// contributed to Lighthouse: https://github.com/gaasedelen/lighthouse -// - -#include -#include -#include - -#include "pin.H" -#include "ImageManager.h" - -using std::ofstream; - -ofstream* g_log; - -#ifdef __i386__ -#define PC "eip" -#else -#define PC "rip" -#endif - -// -// Tool Arguments -// - -static KNOB KnobModuleWhitelist(KNOB_MODE_APPEND, "pintool", "w", "", - "Add a module to the whitelist. If none is specified, every module is white-listed. Example: calc.exe"); - -KNOB KnobOutputFilePrefix(KNOB_MODE_WRITEONCE, "pintool", "o", "trace", - "Prefix of the output file. If none is specified, 'trace' is used."); - -// -// Misc / Util -// - -#if defined(TARGET_WINDOWS) -#define PATH_SEPARATOR "\\" -#else -#define PATH_SEPARATOR "/" -#endif - -static std::string base_name(const std::string& path) -{ - std::string::size_type idx = path.rfind(PATH_SEPARATOR); - std::string name = (idx == std::string::npos) ? path : path.substr(idx + 1); - return name; -} - -// -// Per thread data structure. This is mainly done to avoid locking. -// - Per-thread map of executed basic blocks, and their size. -// - -struct ThreadData -{ - ADDRINT m_cpu_pc; - ADDRINT m_cpu[REG_GR_LAST+1]; - - ADDRINT mem_w_addr; - ADDRINT mem_w_size; - ADDRINT mem_r_addr; - ADDRINT mem_r_size; - ADDRINT mem_r2_addr; - ADDRINT mem_r2_size; - - // Trace file for thread-specific trace modes - ofstream* m_trace; - - char m_scratch[512 * 2]; // fxsave has the biggest memory operand -}; - -// -// Tool Infrastructure -// - -class ToolContext -{ -public: - - ToolContext() - { - PIN_InitLock(&m_loaded_images_lock); - PIN_InitLock(&m_thread_lock); - m_tls_key = PIN_CreateThreadDataKey(nullptr); - } - - ThreadData* GetThreadLocalData(THREADID tid) - { - return static_cast(PIN_GetThreadData(m_tls_key, tid)); - } - - void setThreadLocalData(THREADID tid, ThreadData* data) - { - PIN_SetThreadData(m_tls_key, data, tid); - } - - // The image manager allows us to keep track of loaded images. - ImageManager* m_images; - - // Trace file used for 'monolithic' execution traces. - //TraceFile* m_trace; - - // Keep track of _all_ the loaded images. - std::vector m_loaded_images; - PIN_LOCK m_loaded_images_lock; - - // Thread tracking utilities. - std::set m_seen_threads; - std::vector m_terminated_threads; - PIN_LOCK m_thread_lock; - - // Flag that indicates that tracing is enabled. Always true if there are no whitelisted images. - bool m_tracing_enabled = true; - - // TLS key used to store per-thread data. - TLS_KEY m_tls_key; -}; - -// Thread creation event handler. -static VOID OnThreadStart(THREADID tid, CONTEXT* ctxt, INT32 flags, VOID* v) -{ - // Create a new 'ThreadData' object and set it on the TLS. - auto& context = *reinterpret_cast(v); - auto data = new ThreadData; - memset(data, 0, sizeof(ThreadData)); - - data->m_trace = new ofstream; - context.setThreadLocalData(tid, data); - - char filename[128] = {}; - sprintf(filename, "%s.%u.log", KnobOutputFilePrefix.Value().c_str(), tid); - data->m_trace->open(filename); - *data->m_trace << std::hex; - - // Save the recently created thread. - PIN_GetLock(&context.m_thread_lock, 1); - { - context.m_seen_threads.insert(tid); - } - PIN_ReleaseLock(&context.m_thread_lock); - -} - -// Thread destruction event handler. -static VOID OnThreadFini(THREADID tid, const CONTEXT* ctxt, INT32 c, VOID* v) -{ - // Get thread's 'ThreadData' structure. - auto& context = *reinterpret_cast(v); - ThreadData* data = context.GetThreadLocalData(tid); - - // Remove the thread from the seen threads set and add it to the terminated list. - PIN_GetLock(&context.m_thread_lock, 1); - { - context.m_seen_threads.erase(tid); - context.m_terminated_threads.push_back(data); - } - PIN_ReleaseLock(&context.m_thread_lock); - -} - -// Image unload event handler. -static VOID OnImageLoad(IMG img, VOID* v) -{ - auto& context = *reinterpret_cast(v); - std::string img_name = base_name(IMG_Name(img)); - - ADDRINT low = IMG_LowAddress(img); - ADDRINT high = IMG_HighAddress(img); - - *g_log << "Loaded image: 0x" << low << ":0x" << high << " -> " << img_name << std::endl; - - // Save the loaded image with its original full name/path. - PIN_GetLock(&context.m_loaded_images_lock, 1); - { - context.m_loaded_images.push_back(LoadedImage(IMG_Name(img), low, high)); - } - PIN_ReleaseLock(&context.m_loaded_images_lock); - - // If the image is whitelisted save its information. - if (context.m_images->isWhiteListed(img_name)) - { - context.m_images->addImage(img_name, low, high); - - // Enable tracing if not already enabled. - if (!context.m_tracing_enabled) - context.m_tracing_enabled = true; - } -} - -// Image load event handler. -static VOID OnImageUnload(IMG img, VOID* v) -{ - auto& context = *reinterpret_cast(v); - context.m_images->removeImage(IMG_LowAddress(img)); -} - -// -// Tracing -// - -VOID record_diff(const CONTEXT * cpu, ADDRINT pc, VOID* v) -{ - auto& context = *reinterpret_cast(v); - //printf("Hello from record diff!\n"); - - if (!context.m_tracing_enabled || !context.m_images->isInterestingAddress(pc)) - return; - - auto tid = PIN_ThreadId(); - ThreadData* data = context.GetThreadLocalData(tid); - - // - // dump register delta - // - - ADDRINT val; - auto OutFile = data->m_trace; - - for (int reg = (int)REG_GR_BASE; reg <= (int)REG_GR_LAST; ++reg) { - - // fetch the current register value - PIN_GetContextRegval(cpu, (REG)reg, reinterpret_cast(&val)); - - // if the register didn't change from the last state, nothing to do - if (val == data->m_cpu[reg]) - continue; - - // save the value for the new register to the log - *OutFile << REG_StringShort( (REG) reg) << "=0x" << val << ","; - data->m_cpu[reg] = val; - } - - // always save pc to the log, for every unit of execution - *OutFile << PC << "=0x" << pc; - - // - // dump memory reads / writes - // - - if (data->mem_r_size) - { - memset(data->m_scratch, 0, data->mem_r_size); - - PIN_SafeCopy(data->m_scratch, (const VOID *)data->mem_r_addr, data->mem_r_size); - *OutFile << ",mr=0x" << data->mem_r_addr << ":"; - - for(UINT32 i = 0; i < data->mem_r_size; i++) { - *OutFile << std::hex << std::setw(2) << std::setfill('0') << ((unsigned char)data->m_scratch[i] & 0xff); - } - - data->mem_r_size = 0; - } - - if (data->mem_r2_size) - { - memset(data->m_scratch, 0, data->mem_r2_size); - - PIN_SafeCopy(data->m_scratch, (const VOID *)data->mem_r2_addr, data->mem_r2_size); - *OutFile << ",mr=0x" << data->mem_r2_addr << ":"; - - for(UINT32 i = 0; i < data->mem_r2_size; i++) { - *OutFile << std::hex << std::setw(2) << std::setfill('0') << ((unsigned char)data->m_scratch[i] & 0xff); - } - - data->mem_r2_size = 0; - } - - if (data->mem_w_size) - { - memset(data->m_scratch, 0, data->mem_w_size); - - PIN_SafeCopy(data->m_scratch, (const VOID *)data->mem_w_addr, data->mem_w_size); - *OutFile << ",mw=0x" << data->mem_w_addr << ":"; - - for(UINT32 i = 0; i < data->mem_w_size; i++) { - *OutFile << std::hex << std::setw(2) << std::setfill('0') << ((unsigned char)data->m_scratch[i] & 0xff); - } - - data->mem_w_size = 0; - } - - *OutFile << std::endl; -} - -VOID record_read(THREADID tid, ADDRINT access_addr, UINT32 access_size, VOID * v) { - auto& context = *reinterpret_cast(v); - ThreadData* data = context.GetThreadLocalData(tid); - data->mem_r_addr = access_addr; - data->mem_r_size = access_size; -} - -VOID record_read2(THREADID tid, ADDRINT access_addr, UINT32 access_size, VOID * v) { - auto& context = *reinterpret_cast(v); - ThreadData* data = context.GetThreadLocalData(tid); - data->mem_r2_addr = access_addr; - data->mem_r2_size = access_size; -} - -VOID record_write(THREADID tid, ADDRINT access_addr, UINT32 access_size, VOID * v) { - auto& context = *reinterpret_cast(v); - ThreadData* data = context.GetThreadLocalData(tid); - data->mem_w_addr = access_addr; - data->mem_w_size = access_size; -} - -VOID OnInst(INS ins, VOID* v) { - - // - // *always* dump a diff since the last instruction - // - - INS_InsertCall( - ins, IPOINT_BEFORE, - AFUNPTR(record_diff), - IARG_CONST_CONTEXT, - IARG_INST_PTR, - IARG_PTR, v, - IARG_END); - - // - // if this instruction will perform a mem r/w, inject a call to record the - // address of interest. this will be used by the *next* state diff / dump - // - - if (INS_IsMemoryRead(ins) || INS_IsMemoryWrite(ins)) - { - if (INS_IsMemoryRead(ins)) - { - INS_InsertCall( - ins, IPOINT_BEFORE, - AFUNPTR(record_read), - IARG_THREAD_ID, - IARG_MEMORYREAD_EA, - IARG_MEMORYREAD_SIZE, - IARG_PTR, v, - IARG_END); - } - - if (INS_HasMemoryRead2(ins)) - { - //assert(INS_IsMemoryRead(ins) == false); - INS_InsertCall( - ins, IPOINT_BEFORE, - AFUNPTR(record_read2), - IARG_THREAD_ID, - IARG_MEMORYREAD2_EA, - IARG_MEMORYREAD_SIZE, - IARG_PTR, v, - IARG_END); - } - - if (INS_IsMemoryWrite(ins)) - { - INS_InsertCall( - ins, IPOINT_BEFORE, - AFUNPTR(record_write), - IARG_THREAD_ID, - IARG_MEMORYWRITE_EA, - IARG_MEMORYWRITE_SIZE, - IARG_PTR, v, - IARG_END); - } - } - -} - -static VOID Fini(INT32 code, VOID *v) -{ - auto& context = *reinterpret_cast(v); - - // Add non terminated threads to the list of terminated threads. - for (THREADID i : context.m_seen_threads) { - ThreadData* data = context.GetThreadLocalData(i); - context.m_terminated_threads.push_back(data); - } - - for (const auto& data : context.m_terminated_threads) { - data->m_trace->close(); - } - - g_log->close(); -} - -int main(int argc, char * argv[]) { - - // Initialize symbol processing - PIN_InitSymbols(); - - // Initialize PIN. - if (PIN_Init(argc, argv)) { - std::cerr << "Error initializing PIN, PIN_Init failed!" << std::endl; - return -1; - } - - auto logFile = KnobOutputFilePrefix.Value() + ".log"; - g_log = new ofstream; - g_log->open(logFile.c_str()); - *g_log << std::hex; - - // Initialize the tool context - ToolContext *context = new ToolContext(); - context->m_images = new ImageManager(); - - for (unsigned i = 0; i < KnobModuleWhitelist.NumberOfValues(); ++i) { - *g_log << "White-listing image: " << KnobModuleWhitelist.Value(i) << '\n'; - context->m_images->addWhiteListedImage(KnobModuleWhitelist.Value(i)); - context->m_tracing_enabled = false; - } - - // Handlers for thread creation and destruction. - PIN_AddThreadStartFunction(OnThreadStart, context); - PIN_AddThreadFiniFunction(OnThreadFini, context); - - // Handlers for image loading and unloading. - IMG_AddInstrumentFunction(OnImageLoad, context); - IMG_AddUnloadFunction(OnImageUnload, context); - - // Handlers for instrumentation events. - INS_AddInstrumentFunction(OnInst, context); - - // Handler for program exits. - PIN_AddFiniFunction(Fini, context); - - PIN_StartProgram(); - return 0; -} +// +// pintenet.cpp, a Proof-of-Concept Tenet Tracer +// +// -- by Patrick Biernat & Markus Gaasedelen +// @ RET2 Systems, Inc. +// +// Adaptions from the CodeCoverage pin tool by Agustin Gianni as +// contributed to Lighthouse: https://github.com/gaasedelen/lighthouse +// + +#include +#include +#include + +#include "pin.H" +#include "ImageManager.h" + +using std::ofstream; + +ofstream* g_log; + +#ifdef __i386__ +#define PC "eip" +#else +#define PC "rip" +#endif + +// +// Tool Arguments +// + +static KNOB KnobModuleWhitelist(KNOB_MODE_APPEND, "pintool", "w", "", + "Add a module to the whitelist. If none is specified, every module is white-listed. Example: calc.exe"); + +KNOB KnobOutputFilePrefix(KNOB_MODE_WRITEONCE, "pintool", "o", "trace", + "Prefix of the output file. If none is specified, 'trace' is used."); + +// +// Misc / Util +// + +#if defined(TARGET_WINDOWS) +#define PATH_SEPARATOR "\\" +#else +#define PATH_SEPARATOR "/" +#endif + +static std::string base_name(const std::string& path) +{ + std::string::size_type idx = path.rfind(PATH_SEPARATOR); + std::string name = (idx == std::string::npos) ? path : path.substr(idx + 1); + return name; +} + +// +// Per thread data structure. This is mainly done to avoid locking. +// - Per-thread map of executed basic blocks, and their size. +// + +struct ThreadData +{ + ADDRINT m_cpu_pc; + ADDRINT m_cpu[REG_GR_LAST+1]; + + ADDRINT mem_w_addr; + ADDRINT mem_w_size; + ADDRINT mem_r_addr; + ADDRINT mem_r_size; + ADDRINT mem_r2_addr; + ADDRINT mem_r2_size; + + // Trace file for thread-specific trace modes + ofstream* m_trace; + + char m_scratch[512 * 2]; // fxsave has the biggest memory operand +}; + +// +// Tool Infrastructure +// + +class ToolContext +{ +public: + + ToolContext() + { + PIN_InitLock(&m_loaded_images_lock); + PIN_InitLock(&m_thread_lock); + m_tls_key = PIN_CreateThreadDataKey(nullptr); + } + + ThreadData* GetThreadLocalData(THREADID tid) + { + return static_cast(PIN_GetThreadData(m_tls_key, tid)); + } + + void setThreadLocalData(THREADID tid, ThreadData* data) + { + PIN_SetThreadData(m_tls_key, data, tid); + } + + // The image manager allows us to keep track of loaded images. + ImageManager* m_images; + + // Trace file used for 'monolithic' execution traces. + //TraceFile* m_trace; + + // Keep track of _all_ the loaded images. + std::vector m_loaded_images; + PIN_LOCK m_loaded_images_lock; + + // Thread tracking utilities. + std::set m_seen_threads; + std::vector m_terminated_threads; + PIN_LOCK m_thread_lock; + + // Flag that indicates that tracing is enabled. Always true if there are no whitelisted images. + bool m_tracing_enabled = true; + + // TLS key used to store per-thread data. + TLS_KEY m_tls_key; +}; + +// Thread creation event handler. +static VOID OnThreadStart(THREADID tid, CONTEXT* ctxt, INT32 flags, VOID* v) +{ + // Create a new 'ThreadData' object and set it on the TLS. + auto& context = *reinterpret_cast(v); + auto data = new ThreadData; + memset(data, 0, sizeof(ThreadData)); + + data->m_trace = new ofstream; + context.setThreadLocalData(tid, data); + + char filename[128] = {}; + sprintf(filename, "%s.%u.log", KnobOutputFilePrefix.Value().c_str(), tid); + // NOTE: We do not open the file here because that can cause a deadlock. + // Instead, we will open it lazily on the first call to record_diff. + + // Save the recently created thread. + PIN_GetLock(&context.m_thread_lock, 1); + { + context.m_seen_threads.insert(tid); + } + PIN_ReleaseLock(&context.m_thread_lock); + +} + +// Thread destruction event handler. +static VOID OnThreadFini(THREADID tid, const CONTEXT* ctxt, INT32 c, VOID* v) +{ + // Get thread's 'ThreadData' structure. + auto& context = *reinterpret_cast(v); + ThreadData* data = context.GetThreadLocalData(tid); + + // Remove the thread from the seen threads set and add it to the terminated list. + PIN_GetLock(&context.m_thread_lock, 1); + { + context.m_seen_threads.erase(tid); + context.m_terminated_threads.push_back(data); + } + PIN_ReleaseLock(&context.m_thread_lock); + +} + +// Image unload event handler. +static VOID OnImageLoad(IMG img, VOID* v) +{ + auto& context = *reinterpret_cast(v); + std::string img_name = base_name(IMG_Name(img)); + + ADDRINT low = IMG_LowAddress(img); + ADDRINT high = IMG_HighAddress(img); + + *g_log << "Loaded image: 0x" << low << ":0x" << high << " -> " << img_name << std::endl; + + // Save the loaded image with its original full name/path. + PIN_GetLock(&context.m_loaded_images_lock, 1); + { + context.m_loaded_images.push_back(LoadedImage(IMG_Name(img), low, high)); + } + PIN_ReleaseLock(&context.m_loaded_images_lock); + + // If the image is whitelisted save its information. + if (context.m_images->isWhiteListed(img_name)) + { + context.m_images->addImage(img_name, low, high); + + // Enable tracing if not already enabled. + if (!context.m_tracing_enabled) + context.m_tracing_enabled = true; + } +} + +// Image load event handler. +static VOID OnImageUnload(IMG img, VOID* v) +{ + auto& context = *reinterpret_cast(v); + context.m_images->removeImage(IMG_LowAddress(img)); +} + +// +// Tracing +// + +VOID record_diff(const CONTEXT * cpu, ADDRINT pc, VOID* v) +{ + auto& context = *reinterpret_cast(v); + //printf("Hello from record diff!\n"); + + if (!context.m_tracing_enabled || !context.m_images->isInterestingAddress(pc)) + return; + + auto tid = PIN_ThreadId(); + ThreadData* data = context.GetThreadLocalData(tid); + + // + // dump register delta + // + + ADDRINT val; + auto OutFile = data->m_trace; + + // The file stream is not opened in OnThreadStart because it can cause a deadlock. + // We open it here on first use instead. + if (!OutFile->is_open()) + { + char filename[128] = {}; + sprintf(filename, "%s.%u.log", KnobOutputFilePrefix.Value().c_str(), tid); + OutFile->open(filename, std::ios_base::out | std::ios_base::binary); + *OutFile << std::hex; + } + + for (int reg = (int)REG_GR_BASE; reg <= (int)REG_GR_LAST; ++reg) { + + // fetch the current register value + PIN_GetContextRegval(cpu, (REG)reg, reinterpret_cast(&val)); + + // if the register didn't change from the last state, nothing to do + if (val == data->m_cpu[reg]) + continue; + + // save the value for the new register to the log + *OutFile << REG_StringShort( (REG) reg) << "=0x" << val << ","; + data->m_cpu[reg] = val; + } + + // always save pc to the log, for every unit of execution + *OutFile << PC << "=0x" << pc; + + // + // dump memory reads / writes + // + + if (data->mem_r_size) + { + memset(data->m_scratch, 0, data->mem_r_size); + + PIN_SafeCopy(data->m_scratch, (const VOID *)data->mem_r_addr, data->mem_r_size); + *OutFile << ",mr=0x" << data->mem_r_addr << ":"; + + for(UINT32 i = 0; i < data->mem_r_size; i++) { + *OutFile << std::hex << std::setw(2) << std::setfill('0') << ((unsigned char)data->m_scratch[i] & 0xff); + } + + data->mem_r_size = 0; + } + + if (data->mem_r2_size) + { + memset(data->m_scratch, 0, data->mem_r2_size); + + PIN_SafeCopy(data->m_scratch, (const VOID *)data->mem_r2_addr, data->mem_r2_size); + *OutFile << ",mr=0x" << data->mem_r2_addr << ":"; + + for(UINT32 i = 0; i < data->mem_r2_size; i++) { + *OutFile << std::hex << std::setw(2) << std::setfill('0') << ((unsigned char)data->m_scratch[i] & 0xff); + } + + data->mem_r2_size = 0; + } + + if (data->mem_w_size) + { + memset(data->m_scratch, 0, data->mem_w_size); + + PIN_SafeCopy(data->m_scratch, (const VOID *)data->mem_w_addr, data->mem_w_size); + *OutFile << ",mw=0x" << data->mem_w_addr << ":"; + + for(UINT32 i = 0; i < data->mem_w_size; i++) { + *OutFile << std::hex << std::setw(2) << std::setfill('0') << ((unsigned char)data->m_scratch[i] & 0xff); + } + + data->mem_w_size = 0; + } + + *OutFile << std::endl; +} + +VOID record_read(THREADID tid, ADDRINT access_addr, UINT32 access_size, VOID * v) { + auto& context = *reinterpret_cast(v); + ThreadData* data = context.GetThreadLocalData(tid); + if (!data) return; + data->mem_r_addr = access_addr; + data->mem_r_size = access_size; +} + +VOID record_read2(THREADID tid, ADDRINT access_addr, UINT32 access_size, VOID * v) { + auto& context = *reinterpret_cast(v); + ThreadData* data = context.GetThreadLocalData(tid); + if (!data) return; + data->mem_r2_addr = access_addr; + data->mem_r2_size = access_size; +} + +VOID record_write(THREADID tid, ADDRINT access_addr, UINT32 access_size, VOID * v) { + auto& context = *reinterpret_cast(v); + ThreadData* data = context.GetThreadLocalData(tid); + if (!data) return; + data->mem_w_addr = access_addr; + data->mem_w_size = access_size; +} + +VOID OnInst(INS ins, VOID* v) { + + // + // *always* dump a diff since the last instruction + // + + INS_InsertCall( + ins, IPOINT_BEFORE, + AFUNPTR(record_diff), + IARG_CONST_CONTEXT, + IARG_INST_PTR, + IARG_PTR, v, + IARG_END); + + // + // if this instruction will perform a mem r/w, inject a call to record the + // address of interest. this will be used by the *next* state diff / dump + // + + if (INS_IsMemoryRead(ins) || INS_IsMemoryWrite(ins)) + { + if (INS_IsMemoryRead(ins)) + { + INS_InsertCall( + ins, IPOINT_BEFORE, + AFUNPTR(record_read), + IARG_THREAD_ID, + IARG_MEMORYREAD_EA, + IARG_MEMORYREAD_SIZE, + IARG_PTR, v, + IARG_END); + } + + if (INS_HasMemoryRead2(ins)) + { + //assert(INS_IsMemoryRead(ins) == false); + INS_InsertCall( + ins, IPOINT_BEFORE, + AFUNPTR(record_read2), + IARG_THREAD_ID, + IARG_MEMORYREAD2_EA, + IARG_MEMORYREAD_SIZE, + IARG_PTR, v, + IARG_END); + } + + if (INS_IsMemoryWrite(ins)) + { + INS_InsertCall( + ins, IPOINT_BEFORE, + AFUNPTR(record_write), + IARG_THREAD_ID, + IARG_MEMORYWRITE_EA, + IARG_MEMORYWRITE_SIZE, + IARG_PTR, v, + IARG_END); + } + } + +} + +static VOID Fini(INT32 code, VOID *v) +{ + auto& context = *reinterpret_cast(v); + + // Add non terminated threads to the list of terminated threads. + for (THREADID i : context.m_seen_threads) { + ThreadData* data = context.GetThreadLocalData(i); + context.m_terminated_threads.push_back(data); + } + + for (const auto& data : context.m_terminated_threads) { + data->m_trace->close(); + } + + g_log->close(); +} + +int main(int argc, char* argv[]) { + + // Initialize symbol processing + PIN_InitSymbols(); + + // Initialize PIN. + if (PIN_Init(argc, argv)) { + std::cerr << "Error initializing PIN, PIN_Init failed!" << std::endl; + return -1; + } + + auto logFile = KnobOutputFilePrefix.Value() + ".log"; + g_log = new ofstream; + g_log->open(logFile.c_str()); + *g_log << std::hex; + + // Initialize the tool context + ToolContext *context = new ToolContext(); + context->m_images = new ImageManager(); + + for (unsigned i = 0; i < KnobModuleWhitelist.NumberOfValues(); ++i) { + *g_log << "White-listing image: " << KnobModuleWhitelist.Value(i) << '\n'; + context->m_images->addWhiteListedImage(KnobModuleWhitelist.Value(i)); + context->m_tracing_enabled = false; + } + + // Handlers for thread creation and destruction. + PIN_AddThreadStartFunction(OnThreadStart, context); + PIN_AddThreadFiniFunction(OnThreadFini, context); + + // Handlers for image loading and unloading. + IMG_AddInstrumentFunction(OnImageLoad, context); + IMG_AddUnloadFunction(OnImageUnload, context); + + // Handlers for instrumentation events. + INS_AddInstrumentFunction(OnInst, context); + + // Handler for program exits. + PIN_AddFiniFunction(Fini, context); + + PIN_StartProgram(); + return 0; +}