diff --git a/.github/workflows/testAndPublish.yml b/.github/workflows/testAndPublish.yml index f4990055655..928d75875f6 100644 --- a/.github/workflows/testAndPublish.yml +++ b/.github/workflows/testAndPublish.yml @@ -401,6 +401,7 @@ jobs: - startupShutdown - symbols - vscode + - imageDescriptions - chrome_annotations - chrome_list - chrome_table diff --git a/pyproject.toml b/pyproject.toml index 0e41fd3bc15..3fc86063d16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ dependencies = [ "l2m4m==1.0.4", "pyyaml==6.0.3", "pymdown-extensions==10.17.1", + # local image caption + "onnxruntime==1.23.2", + "numpy==2.3.5", ] [project.urls] @@ -335,6 +338,7 @@ system-tests = [ "robotframework==7.3.2", "robotremoteserver==1.1.1", "robotframework-screencaplibrary==1.6.0", + "onnx==1.19.1", ] unit-tests = [ # Creating XML unit test reports diff --git a/source/NVDAState.py b/source/NVDAState.py index 6f6b079aab1..f08f227dd83 100644 --- a/source/NVDAState.py +++ b/source/NVDAState.py @@ -67,6 +67,10 @@ def voiceDictsBackupDir(self) -> str: def updatesDir(self) -> str: return os.path.join(self.configDir, "updates") + @property + def modelsDir(self) -> str: + return os.path.join(self.configDir, "models") + @property def nvdaConfigFile(self) -> str: return os.path.join(self.configDir, "nvda.ini") diff --git a/source/_localCaptioner/__init__.py b/source/_localCaptioner/__init__.py new file mode 100644 index 00000000000..3d55b5c486a --- /dev/null +++ b/source/_localCaptioner/__init__.py @@ -0,0 +1,44 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. +# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +from logHandler import log + +from .imageDescriber import ImageDescriber +from . 
import modelConfig + +_localCaptioner: ImageDescriber | None = None + + +def initialize(): + """Initialise the local captioner.""" + global _localCaptioner + log.debug("Initializing local captioner") + modelConfig.initialize() + _localCaptioner = ImageDescriber() + + +def terminate(): + """Terminate the local captioner.""" + global _localCaptioner + if _localCaptioner is None: + log.error("local captioner not running") + return + log.debug("Terminating local captioner") + _localCaptioner.terminate() + _localCaptioner = None + + +def isModelLoaded() -> bool: + """return if model is loaded""" + if _localCaptioner is not None: + return _localCaptioner.isModelLoaded + else: + return False + + +def toggleImageCaptioning() -> None: + """do load/unload the model from memory.""" + if _localCaptioner is not None: + _localCaptioner.toggleSwitch() diff --git a/source/_localCaptioner/captioner/__init__.py b/source/_localCaptioner/captioner/__init__.py new file mode 100644 index 00000000000..a1f16590a0c --- /dev/null +++ b/source/_localCaptioner/captioner/__init__.py @@ -0,0 +1,53 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. +# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +import json + +from logHandler import log +from .base import ImageCaptioner + + +def imageCaptionerFactory( + configPath: str, + encoderPath: str | None = None, + decoderPath: str | None = None, + monomericModelPath: str | None = None, +) -> ImageCaptioner: + """Initialize the image caption generator. + + :param monomericModelPath: Path to a single merged model file. + :param encoderPath: Path to the encoder model file. + :param decoderPath: Path to the decoder model file. + :param configPath: Path to the configuration file. 
+ :raises ValueError: If neither a single model nor both encoder and decoder are provided. + :raises FileNotFoundError: If config file not found. + :raises NotImplementedError: if model architecture is unsupported + :raises Exception: If config.json fail to load. + :return: instance of ImageCaptioner + """ + if not monomericModelPath and not (encoderPath and decoderPath): + raise ValueError( + "You must provide either 'monomericModelPath' or both 'encoderPath' and 'decoderPath'.", + ) + + try: + with open(configPath, "r", encoding="utf-8") as f: + config = json.load(f) + except FileNotFoundError: + raise FileNotFoundError( + f"Caption model config file {configPath} not found, " + "please download models and config file first!", + ) + except Exception: + log.exception("config file not found") + raise + + modelArchitecture = config["architectures"][0] + if modelArchitecture == "VisionEncoderDecoderModel": + from .vitGpt2 import VitGpt2ImageCaptioner + + return VitGpt2ImageCaptioner(encoderPath, decoderPath, configPath) + else: + raise NotImplementedError("Unsupported model architectures") diff --git a/source/_localCaptioner/captioner/base.py b/source/_localCaptioner/captioner/base.py new file mode 100644 index 00000000000..ba7ea116f79 --- /dev/null +++ b/source/_localCaptioner/captioner/base.py @@ -0,0 +1,24 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. +# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +from abc import ABC, abstractmethod + + +class ImageCaptioner(ABC): + """Abstract interface for image caption generation. 
+ + Supports generate caption for image + """ + + @abstractmethod + def generateCaption(self, image: str | bytes, maxLength: int | None = None) -> str: + """ + Generate a caption for the given image. + + :param image: Image file path or binary data. + :param maxLength: Optional maximum length for the generated caption. + :return: The generated image caption as a string. + """ + pass diff --git a/source/_localCaptioner/captioner/vitGpt2.py b/source/_localCaptioner/captioner/vitGpt2.py new file mode 100644 index 00000000000..47af56c9266 --- /dev/null +++ b/source/_localCaptioner/captioner/vitGpt2.py @@ -0,0 +1,382 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. +# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +import os +import json +import re +import io +from functools import lru_cache + +import numpy as np +from PIL import Image + +from logHandler import log + +from .base import ImageCaptioner +from ..modelConfig import ( + _EncoderConfig, + _DecoderConfig, + _GenerationConfig, + _ModelConfig, + _PreprocessorConfig, + _createConfigFromDict, +) +from .. import modelConfig + + +class VitGpt2ImageCaptioner(ImageCaptioner): + """Lightweight ONNX Runtime image captioning model. + + This class provides image captioning functionality using ONNX models + without PyTorch dependencies. It uses a Vision Transformer encoder + and GPT-2 decoder for generating captions. + """ + + def __init__( + self, + encoderPath: str, + decoderPath: str, + configPath: str, + enableProfiling: bool = False, + ) -> None: + """Initialize the lightweight ONNX image captioning model. + + :param encoderPath: Path to the ViT encoder ONNX model. + :param decoderPath: Path to the GPT-2 decoder ONNX model. 
+ :param configPath: Path to the configuration file (required). + :param enableProfiling: Whether to enable ONNX Runtime profiling. + :raises FileNotFoundError: If config file is not found. + :raises Exception: If model initialization fails. + """ + # Import late to avoid importing numpy at initialization + import onnxruntime as ort + + # Load configuration file + try: + with open(configPath, "r", encoding="utf-8") as f: + self.config = json.load(f) + except FileNotFoundError: + raise FileNotFoundError( + f"Caption model config file {configPath} not found, " + "please download models and config file first!", + ) + except Exception: + raise + + # Load vocabulary from vocab.json in the same directory as config + configDir = os.path.dirname(configPath) + vocabPath = os.path.join(configDir, "vocab.json") + self.vocab = self._loadVocab(vocabPath) + self.vocabSize = len(self.vocab) + + preprocessorPath = os.path.join(configDir, "preprocessor_config.json") + self.preprocessorConfig = self._loadPreprocessorConfig(preprocessorPath) + + # Load all model parameters from configuration + self._loadModelParams() + + # Configure ONNX Runtime session + sessionOptions = ort.SessionOptions() + if enableProfiling: + sessionOptions.enable_profiling = True + + # Load ONNX models + try: + self.encoderSession = ort.InferenceSession(encoderPath, sess_options=sessionOptions) + self.decoderSession = ort.InferenceSession(decoderPath, sess_options=sessionOptions) + except ( + ort.capi.onnxruntime_pybind11_state.InvalidProtobuf, + ort.capi.onnxruntime_pybind11_state.NoSuchFile, + ) as e: + raise FileNotFoundError( + "model file incomplete" + f" Please check whether the file is complete or re-download. 
Original error: {e}", + ) from e + + log.debug( + f"Loaded ONNX models - Encoder: {os.path.basename(encoderPath)}, Decoder: {os.path.basename(decoderPath)}", + ) + log.debug(f"Loaded config : {os.path.basename(configPath)}") + log.debug(f"Loaded vocabulary : {os.path.basename(vocabPath)}") + log.debug( + f"Model config - Image size: {self.encoderConfig.image_size}, Max length: {self.decoderConfig.max_length}", + ) + + def _loadModelParams(self) -> None: + """Load all model parameters from configuration file.""" + # Load encoder configuration + encoder_dict = self.config.get("encoder", {}) + self.encoderConfig = _createConfigFromDict( + _EncoderConfig, + encoder_dict, + modelConfig._DEFAULT_ENCODER_CONFIG, + ) + + # Load decoder configuration + decoder_dict = self.config.get("decoder", {}) + self.decoderConfig = _createConfigFromDict( + _DecoderConfig, + decoder_dict, + modelConfig._DEFAULT_DECODER_CONFIG, + ) + + # Load generation configuration + generation_dict = self.config.get("generation", {}) + self.generationConfig = _createConfigFromDict( + _GenerationConfig, + generation_dict, + modelConfig._DEFAULT_GENERATION_CONFIG, + ) + + # Load main model configuration + self.modelConfig = _createConfigFromDict(_ModelConfig, self.config, modelConfig._DEFAULT_MODEL_CONFIG) + + def _loadVocab(self, vocabPath: str) -> dict[int, str]: + """Load vocabulary file. + + :param vocabPath: Path to vocab.json file. + :return: Dictionary mapping token IDs to tokens. 
+ """ + try: + with open(vocabPath, "r", encoding="utf-8") as f: + vocabData = json.load(f) + + # Convert to id -> token format + vocab = {v: k for k, v in vocabData.items()} + log.debug(f"Successfully loaded vocabulary with {len(vocab)} tokens") + return vocab + + except FileNotFoundError: + log.exception(f"vocab.json not found at {vocabPath}") + raise + except Exception: + log.exception(f"Could not load vocabulary from {vocabPath}") + raise + + def _loadPreprocessorConfig(self, preprocessorPath: str) -> _PreprocessorConfig: + """Load preprocessor configuration from preprocessor_config.json.""" + try: + with open(preprocessorPath, "r", encoding="utf-8") as f: + preprocessor_dict = json.load(f) + except FileNotFoundError: + log.warning("Preprocessor config not found, using defaults") + return modelConfig._DEFAULT_PREPROCESSOR_CONFIG + else: + return _createConfigFromDict( + _PreprocessorConfig, + preprocessor_dict, + modelConfig._DEFAULT_PREPROCESSOR_CONFIG, + ) + + def _preprocessImage(self, image: str | bytes) -> np.ndarray: + """Preprocess image for model input using external configuration. + + :param image: Image file path or binary data. + :return: Preprocessed image array ready for model input. 
+ """ + # Load image + if isinstance(image, str) and os.path.isfile(image): + img = Image.open(image).convert("RGB") + else: + img = Image.open(io.BytesIO(image)).convert("RGB") + + # Resize image if configured + if self.preprocessorConfig.do_resize: + target_size = ( + self.preprocessorConfig.size["width"], + self.preprocessorConfig.size["height"], + ) + # Map resample integer to PIL constant + resample_map = { + 0: Image.NEAREST, + 1: Image.LANCZOS, + 2: Image.BILINEAR, + 3: Image.BICUBIC, + 4: Image.BOX, + 5: Image.HAMMING, + } + resample_method = resample_map.get(self.preprocessorConfig.resample, Image.LANCZOS) + img = img.resize(target_size, resample_method) + + # Convert to numpy array + imgArray = np.array(img).astype(np.float32) + + # Rescale if configured (typically from [0, 255] to [0, 1]) + if self.preprocessorConfig.do_rescale: + imgArray = imgArray * self.preprocessorConfig.rescale_factor + + # Normalize if configured + if self.preprocessorConfig.do_normalize: + mean = np.array(self.preprocessorConfig.image_mean, dtype=np.float32) + std = np.array(self.preprocessorConfig.image_std, dtype=np.float32) + imgArray = (imgArray - mean) / std + + # Adjust dimensions: (H, W, C) -> (1, C, H, W) + imgArray = np.transpose(imgArray, (2, 0, 1)) + imgArray = np.expand_dims(imgArray, axis=0) + + return imgArray + + def _encodeImage(self, imageArray: np.ndarray) -> np.ndarray: + """Encode image using ViT encoder. + + :param imageArray: Preprocessed image array. + :return: Encoder hidden states. + """ + # Get encoder input name + inputName = self.encoderSession.get_inputs()[0].name + + # Run encoder inference + imageArray = imageArray.astype(np.float32) + encoderOutputs = self.encoderSession.run(None, {inputName: imageArray}) + + # Return last hidden state + return encoderOutputs[0] + + def _decodeTokens(self, tokenIds: list[int]) -> str: + """Decode token IDs to text. + + :param tokenIds: List of token IDs. + :return: Decoded text string. 
+ """ + tokens = [] + for tokenId in tokenIds: + if tokenId in self.vocab: + token = self.vocab[tokenId] + if token not in ["<|endoftext|>", "<|pad|>"]: + tokens.append(token) + + # Simple text post-processing + # Ġ (Unicode U+0120) is used by GPT-2 and RoBERTa to indicate space at the beginning of a word in their vocabulary + text = " ".join(tokens).replace("Ġ", " ") + + # Basic text cleaning + text = re.sub(r"\s+", " ", text) # Merge multiple spaces + text = text.strip() + + return text + + def _getDecoderInputNames(self) -> list[str]: + """Get decoder input names for debugging. + + :returns: List of decoder input names. + """ + return [inp.name for inp in self.decoderSession.get_inputs()] + + def _getDecoderOutputNames(self) -> list[str]: + """Get decoder output names for debugging. + + :return: List of decoder output names. + """ + return [out.name for out in self.decoderSession.get_outputs()] + + def _initializePastKeyValues(self, batchSize: int = 1) -> dict[str, np.ndarray]: + """Initialize past_key_values for decoder. + + :param batchSize: Batch size for inference. + :return: Dictionary of initialized past key values. + """ + pastKeyValues = {} + + # Create key and value for each layer + for layerIdx in range(self.decoderConfig.n_layer): + # Key and value shape: (batch_size, num_heads, 0, head_dim) + # Initial sequence length is 0 + headDim = self.decoderConfig.n_embd // self.decoderConfig.n_head + + keyShape = (batchSize, self.decoderConfig.n_head, 0, headDim) + valueShape = (batchSize, self.decoderConfig.n_head, 0, headDim) + + pastKeyValues[f"past_key_values.{layerIdx}.key"] = np.zeros(keyShape, dtype=np.float32) + pastKeyValues[f"past_key_values.{layerIdx}.value"] = np.zeros(valueShape, dtype=np.float32) + + return pastKeyValues + + def _generateWithGreedy( + self, + encoderHiddenStates: np.ndarray, + maxLength: int | None = None, + ) -> str: + """Generate text using greedy search. + + + :param encoderHiddenStates: Encoder hidden states. 
+ :param maxLength: Maximum generation length. + :return: Generated text string. + """ + if maxLength is None: + maxLength = self.decoderConfig.max_length + + # Initialize input sequence + inputIds = np.array([[self.modelConfig.bos_token_id]], dtype=np.int64) + generatedTokens = [] + + # Initialize past_key_values + pastKeyValues = self._initializePastKeyValues(batchSize=1) + + for step in range(maxLength): + # Prepare decoder inputs + decoderInputs = { + "input_ids": inputIds if step == 0 else np.array([[generatedTokens[-1]]], dtype=np.int64), + "encoder_hidden_states": encoderHiddenStates, + "use_cache_branch": np.array([1], dtype=np.bool_), + } + + # Add past_key_values to inputs + decoderInputs.update(pastKeyValues) + + # Run decoder + decoderOutputs = self.decoderSession.run(None, decoderInputs) + logits = decoderOutputs[0] # Shape: (batch_size, seq_len, vocab_size) + + # Greedy selection of next token + nextTokenLogits = logits[0, -1, :] # Logits for last position + nextTokenId = int(np.argmax(nextTokenLogits)) + + # Check if generation should end + if nextTokenId == self.modelConfig.eos_token_id: + break + + generatedTokens.append(nextTokenId) + + # Update past_key_values from outputs + if len(decoderOutputs) > 1: + for layerIdx in range(self.decoderConfig.n_layer): + if len(decoderOutputs) > 1 + layerIdx * 2 + 1: + # [3] -> layer1 key, [4] -> layer1 value + keyIndex = 1 + layerIdx * 2 + valueIndex = keyIndex + 1 + pastKeyValues[f"past_key_values.{layerIdx}.key"] = decoderOutputs[keyIndex] + pastKeyValues[f"past_key_values.{layerIdx}.value"] = decoderOutputs[valueIndex] + + # Avoid sequences that are too long + if len(generatedTokens) >= self.decoderConfig.n_ctx - 1: + break + + # Decode generated text + return self._decodeTokens(generatedTokens) + + @lru_cache() + def generateCaption( + self, + image: str | bytes, + maxLength: int | None = None, + ) -> str: + """Generate image caption. + + :param image: Image file path or binary data. 
+ :param maxLength: Maximum generation length. + :return: Generated image caption. + """ + # Preprocess image + imageArray = self._preprocessImage(image) + + # Encode image + encoderHiddenStates = self._encodeImage(imageArray) + + # Generate text + caption = self._generateWithGreedy(encoderHiddenStates, maxLength) + + return caption diff --git a/source/_localCaptioner/imageDescriber.py b/source/_localCaptioner/imageDescriber.py new file mode 100644 index 00000000000..1e193789ebf --- /dev/null +++ b/source/_localCaptioner/imageDescriber.py @@ -0,0 +1,214 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. +# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +"""ImageDescriber module for NVDA. + +This module provides local image captioning functionality using ONNX models. +It allows users to capture screen regions and generate captions using local AI models. +""" + +import io +import threading +from threading import Thread +import os + +import wx +import config +from logHandler import log +import ui +import api +from keyboardHandler import KeyboardInputGesture +from NVDAState import WritePaths +import core + +from .captioner import ImageCaptioner +from .captioner import imageCaptionerFactory + + +# Module-level configuration +_localCaptioner = None + + +def _screenshotNavigator() -> bytes: + """Capture a screenshot of the current navigator object. + + :Return: The captured image data as bytes in JPEG format. 
+ """ + # Get the currently focused object on screen + obj = api.getNavigatorObject() + + # Get the object's position and size information + x, y, width, height = obj.location + + # Create a bitmap with the same size as the object + bmp = wx.Bitmap(width, height) + + # Create a memory device context for drawing operations on the bitmap + mem = wx.MemoryDC(bmp) + + # Copy the specified screen region to the memory bitmap + mem.Blit(0, 0, width, height, wx.ScreenDC(), x, y) + + # Convert the bitmap to an image object for more flexible operations + image = bmp.ConvertToImage() + + # Create a byte stream object to save image data as binary data + body = io.BytesIO() + + # Save the image to the byte stream in JPEG format + image.SaveFile(body, wx.BITMAP_TYPE_JPEG) + + # Read the binary image data from the byte stream + imageData = body.getvalue() + return imageData + + +def _messageCaption(captioner: ImageCaptioner, imageData: bytes) -> None: + """Generate a caption for the given image data. + + :param captioner: The captioner instance to use for generation. + :param imageData: The image data to caption. + """ + try: + description = captioner.generateCaption(image=imageData) + except Exception: + # Translators: error message when an image description cannot be generated + wx.CallAfter(ui.message, pgettext("imageDesc", "Failed to generate description")) + log.exception("Failed to generate caption") + else: + wx.CallAfter( + ui.message, + # Translators: Presented when an AI image description has been generated. + # {description} will be replaced with the generated image description. + pgettext("imageDesc", "Could be: {description}").format(description=description), + ) + + +class ImageDescriber: + """module for local image caption functionality. + + This module provides image captioning using local ONNX models. + It can capture screen regions and generate descriptive captions. 
+ """ + + def __init__(self) -> None: + self.isModelLoaded = False + self.captioner: ImageCaptioner | None = None + self.captionThread: Thread | None = None + self.loadModelThread: Thread | None = None + + enable = config.conf["automatedImageDescriptions"]["enable"] + # Load model when initializing (may cause high memory usage) + if enable: + core.postNvdaStartup.register(self.loadModelInBackground) + + def terminate(self): + for t in [self.captionThread, self.loadModelThread]: + if t is not None and t.is_alive(): + t.join() + + self.captioner = None + + def runCaption(self, gesture: KeyboardInputGesture) -> None: + """Script to run image captioning on the current navigator object. + + :param gesture: The input gesture that triggered this script. + """ + self._doCaption() + + def _doCaption(self) -> None: + """Real logic to run image captioning on the current navigator object.""" + imageData = _screenshotNavigator() + + if not self.isModelLoaded: + from gui._localCaptioner.messageDialogs import openEnableOnceDialog + + # Ask to enable image desc only in this session, No configuration modifications + wx.CallAfter(openEnableOnceDialog) + return + + if self.captionThread is not None and self.captionThread.is_alive(): + return + + self.captionThread = threading.Thread( + target=_messageCaption, + args=(self.captioner, imageData), + name="RunCaptionThread", + ) + # Translators: Message when starting image recognition + ui.message(pgettext("imageDesc", "getting image description...")) + self.captionThread.start() + + def _loadModel(self, localModelDirPath: str | None = None) -> None: + """Load the ONNX model for image captioning. 
+ + :param localModelDirPath: path of model directory + """ + + if not localModelDirPath: + baseModelsDir = WritePaths.modelsDir + localModelDirPath = os.path.join( + baseModelsDir, + config.conf["automatedImageDescriptions"]["defaultModel"], + ) + encoderPath = f"{localModelDirPath}/onnx/encoder_model_quantized.onnx" + decoderPath = f"{localModelDirPath}/onnx/decoder_model_merged_quantized.onnx" + configPath = f"{localModelDirPath}/config.json" + + try: + self.captioner = imageCaptionerFactory( + encoderPath=encoderPath, + decoderPath=decoderPath, + configPath=configPath, + ) + except FileNotFoundError: + self.isModelLoaded = False + from gui._localCaptioner.messageDialogs import ImageDescDownloader + + descDownloader = ImageDescDownloader() + wx.CallAfter(descDownloader.openDownloadDialog) + except Exception: + self.isModelLoaded = False + # Translators: error message when fail to load model + wx.CallAfter(ui.message, pgettext("imageDesc", "failed to load image captioner")) + log.exception("Failed to load image captioner model") + else: + self.isModelLoaded = True + # Translators: Message when successfully load the model + wx.CallAfter(ui.message, pgettext("imageDesc", "image captioning on")) + + def loadModelInBackground(self, localModelDirPath: str | None = None) -> None: + """load model in child thread + + :param localModelDirPath: path of model directory + """ + self.loadModelThread = threading.Thread( + target=self._loadModel, + args=(localModelDirPath,), + name="LoadModelThread", + ) + self.loadModelThread.start() + + def _doReleaseModel(self) -> None: + if hasattr(self, "captioner") and self.captioner: + del self.captioner + self.captioner = None + # Translators: Message when image captioning terminates + ui.message(pgettext("imageDesc", "image captioning off")) + self.isModelLoaded = False + + def toggleSwitch(self) -> None: + """do load/unload the model from memory.""" + if self.isModelLoaded: + self._doReleaseModel() + else: + 
self.loadModelInBackground() + + def toggleImageCaptioning(self, gesture: KeyboardInputGesture) -> None: + """do load/unload the model from memory. + + :param gesture: gesture to toggle this function + """ + self.toggleSwitch() diff --git a/source/_localCaptioner/modelConfig.py b/source/_localCaptioner/modelConfig.py new file mode 100644 index 00000000000..f5b705e482a --- /dev/null +++ b/source/_localCaptioner/modelConfig.py @@ -0,0 +1,277 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. +# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +from dataclasses import dataclass, fields, replace +from typing import Type + + +@dataclass(frozen=True) +class _EncoderConfig: + """Configuration for Vision Transformer encoder. + + Based on the Vision Transformer (ViT) specification: + https://arxiv.org/abs/2010.11929 + + HuggingFace ViT configuration: + https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTConfig + + Note: Variable names follow the original specification and HuggingFace conventions + rather than lowerCamelCase to maintain compatibility with pretrained models. 
+ """ + + image_size: int = 224 + num_channels: int = 3 + patch_size: int = 16 + hidden_size: int = 768 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + intermediate_size: int = 3072 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.0 + attention_probs_dropout_prob: float = 0.0 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-12 + encoder_stride: int = 16 + qkv_bias: bool = True + model_type: str = "vit" + # Additional fields from HuggingFace config + add_cross_attention: bool = False + is_decoder: bool = False + is_encoder_decoder: bool = False + chunk_size_feed_forward: int = 0 + cross_attention_hidden_size: int | None = None + finetuning_task: str | None = None + output_attentions: bool = False + output_hidden_states: bool = False + return_dict: bool = True + pruned_heads: dict[str, list[int]] | None = None + tie_word_embeddings: bool = True + torch_dtype: str | None = None + torchscript: bool = False + use_bfloat16: bool = False + + +@dataclass(frozen=True) +class _DecoderConfig: + """Configuration for GPT-2 decoder. + + Based on the GPT-2 specification: + https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf + + HuggingFace GPT-2 configuration: + https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Config + + Note: Variable names follow the original GPT-2 and HuggingFace conventions + rather than lowerCamelCase to maintain compatibility with pretrained models. 
+ """ + + vocab_size: int = 50257 + n_embd: int = 768 + n_layer: int = 12 + n_head: int = 12 + n_ctx: int = 1024 + n_positions: int = 1024 + n_inner: int | None = None + activation_function: str = "gelu_new" + resid_pdrop: float = 0.1 + embd_pdrop: float = 0.1 + attn_pdrop: float = 0.1 + layer_norm_epsilon: float = 1e-05 + initializer_range: float = 0.02 + model_type: str = "gpt2" + # Generation parameters + max_length: int = 20 + min_length: int = 0 + do_sample: bool = False + early_stopping: bool = False + num_beams: int = 1 + num_beam_groups: int = 1 + diversity_penalty: float = 0.0 + temperature: float = 1.0 + top_k: int = 50 + top_p: float = 1.0 + typical_p: float = 1.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + no_repeat_ngram_size: int = 0 + encoder_no_repeat_ngram_size: int = 0 + num_return_sequences: int = 1 + # Cross attention + add_cross_attention: bool = True + is_decoder: bool = True + is_encoder_decoder: bool = False + # Token IDs + bos_token_id: int = 50256 + eos_token_id: int = 50256 + pad_token_id: int = 50256 + decoder_start_token_id: int = 50256 + # Additional configuration + chunk_size_feed_forward: int = 0 + cross_attention_hidden_size: int | None = None + bad_words_ids: list[int] | None = None + begin_suppress_tokens: list[int] | None = None + forced_bos_token_id: int | None = None + forced_eos_token_id: int | None = None + suppress_tokens: list[int] | None = None + exponential_decay_length_penalty: float | None = None + remove_invalid_values: bool = False + return_dict_in_generate: bool = False + output_attentions: bool = False + output_hidden_states: bool = False + output_scores: bool = False + use_cache: bool = True + # Labels + id2label: dict[str, str] | None = None + label2id: dict[str, int] | None = None + # Scaling and attention + reorder_and_upcast_attn: bool = False + scale_attn_by_inverse_layer_idx: bool = False + scale_attn_weights: bool = True + # Summary configuration + summary_activation: str | None = None 
+ summary_first_dropout: float = 0.1 + summary_proj_to_labels: bool = True + summary_type: str = "cls_index" + summary_use_proj: bool = True + # Task specific parameters + task_specific_params: dict[str, any] | None = None + # Other configurations + finetuning_task: str | None = None + prefix: str | None = None + problem_type: str | None = None + pruned_heads: dict[str, list[int]] | None = None + sep_token_id: int | None = None + tf_legacy_loss: bool = False + tie_encoder_decoder: bool = False + tie_word_embeddings: bool = True + tokenizer_class: str | None = None + torch_dtype: str | None = None + torchscript: bool = False + use_bfloat16: bool = False + + +@dataclass(frozen=True) +class _GenerationConfig: + """Configuration for text generation parameters. + + Based on HuggingFace GenerationConfig: + https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig + + Note: Variable names follow HuggingFace conventions rather than lowerCamelCase + to maintain compatibility with the transformers library. + """ + + do_sample: bool = False + num_beams: int = 1 + temperature: float = 1.0 + top_k: int = 50 + top_p: float = 1.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + max_length: int = 20 + min_length: int = 0 + early_stopping: bool = False + diversity_penalty: float = 0.0 + num_beam_groups: int = 1 + no_repeat_ngram_size: int = 0 + num_return_sequences: int = 1 + + +@dataclass(frozen=True) +class _ModelConfig: + """Main model configuration. + + Based on HuggingFace VisionEncoderDecoderConfig: + https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder#transformers.VisionEncoderDecoderConfig + + Note: Variable names follow HuggingFace conventions rather than lowerCamelCase + to maintain compatibility with pretrained models. 
+ """ + + model_type: str = "vision-encoder-decoder" + is_encoder_decoder: bool = True + tie_word_embeddings: bool = False + bos_token_id: int = 50256 + eos_token_id: int = 50256 + pad_token_id: int = 50256 + decoder_start_token_id: int = 50256 + transformers_version: str = "4.33.0.dev0" + architectures: list[str] | None = None + + +@dataclass(frozen=True) +class _PreprocessorConfig: + """Configuration for image preprocessing. + + Based on HuggingFace ViTFeatureExtractor/ViTImageProcessor: + https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTFeatureExtractor + https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTImageProcessor + + Note: Variable names follow HuggingFace conventions rather than lowerCamelCase + to maintain compatibility with the transformers library. + """ + + do_normalize: bool = True + do_rescale: bool = True + do_resize: bool = True + feature_extractor_type: str = "ViTFeatureExtractor" + image_processor_type: str = "ViTFeatureExtractor" + image_mean: list[float] | None = None + image_std: list[float] | None = None + resample: int = 2 # PIL.Image.LANCZOS + rescale_factor: float = 0.00392156862745098 # 1/255 + size: dict[str, int] | None = None + + def __post_init__(self): + """Initialize default values for mutable fields.""" + if self.image_mean is None: + object.__setattr__(self, "image_mean", [0.5, 0.5, 0.5]) + if self.image_std is None: + object.__setattr__(self, "image_std", [0.5, 0.5, 0.5]) + if self.size is None: + object.__setattr__(self, "size", {"height": 224, "width": 224}) + + +# Default configuration instances +_DEFAULT_ENCODER_CONFIG: _EncoderConfig | None = None +_DEFAULT_DECODER_CONFIG: _DecoderConfig | None = None +_DEFAULT_GENERATION_CONFIG: _GenerationConfig | None = None +_DEFAULT_MODEL_CONFIG: _ModelConfig | None = None +_DEFAULT_PREPROCESSOR_CONFIG: _PreprocessorConfig | None = None + + +def initialize(): + global \ + _DEFAULT_ENCODER_CONFIG, \ + _DEFAULT_DECODER_CONFIG, \ + 
_DEFAULT_GENERATION_CONFIG, \ + _DEFAULT_MODEL_CONFIG, \ + _DEFAULT_PREPROCESSOR_CONFIG + _DEFAULT_ENCODER_CONFIG = _EncoderConfig() + _DEFAULT_DECODER_CONFIG = _DecoderConfig() + _DEFAULT_GENERATION_CONFIG = _GenerationConfig() + _DEFAULT_MODEL_CONFIG = _ModelConfig() + _DEFAULT_PREPROCESSOR_CONFIG = _PreprocessorConfig() + + +def _createConfigFromDict[T]( + configClass: Type[T], + configdict: dict[str, str | int | float | bool | list | dict | None], + defaultConfig: T, +) -> T: + """Create a dataclass instance from a dictionary with automatic field mapping. + + :param configClass: The dataclass type to create + :param configdict: dictionary containing configuration values + :param defaultConfig: Default configuration instance + :return: New dataclass instance with values from configdict or defaults + """ + # Get all field names from the dataclass + fieldNames = {f.name for f in fields(configClass)} + + # Filter configdict to only include valid field names + validUpdates = {fieldName: value for fieldName, value in configdict.items() if fieldName in fieldNames} + + return replace(defaultConfig, **validUpdates) diff --git a/source/_localCaptioner/modelDownloader.py b/source/_localCaptioner/modelDownloader.py new file mode 100644 index 00000000000..1f725b95f87 --- /dev/null +++ b/source/_localCaptioner/modelDownloader.py @@ -0,0 +1,760 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. +# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +""" +Multi‑threaded model downloader + +Download ONNX / tokenizer assets from *Hugging Face* (or any HTTP host) +with progress callbacks. Refactored to use requests library. 
+""" + +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from collections.abc import Callable + +import requests +from requests.adapters import HTTPAdapter +from requests.exceptions import RequestException +from requests.models import Response +from urllib3.util.retry import Retry + +from logHandler import log +import config +from NVDAState import WritePaths + +# Type definitions +ProgressCallback = Callable[[str, int, int, float], None] + +# Constants +CHUNK_SIZE: int = 8_192 +MAX_RETRIES: int = 3 +BACKOFF_BASE: int = 2 # Base delay (in seconds) for exponential backoff strategy + + +class ModelDownloader: + """Multi-threaded model downloader with progress tracking and retry logic.""" + + def __init__( + self, + remoteHost: str = "huggingface.co", + maxWorkers: int = 4, + maxRetries: int = MAX_RETRIES, + ): + """ + Initialize the ModelDownloader. + + :param remoteHost: Remote host URL (default: huggingface.co). + :param maxWorkers: Maximum number of worker threads. + :param maxRetries: Maximum retry attempts per file. 
+ """ + self.remoteHost = remoteHost + self.maxWorkers = maxWorkers + self.maxRetries = maxRetries + + # Thread control + self.cancelRequested = False + self.downloadLock = threading.Lock() + self.activeFutures = set() + + # Configure requests session with retry strategy and automatic redirects + self.session = requests.Session() + + # Configure retry strategy + retryStrategy = Retry( + # Maximum number of retries before giving up + total=maxRetries, + # Base factor for calculating delay between retries + backoff_factor=BACKOFF_BASE, + # HTTP status codes that trigger a retry + status_forcelist=[429, 500, 502, 503, 504], + # HTTP methods allowed to retry + allowed_methods=["HEAD", "GET", "OPTIONS"], + ) + + adapter = HTTPAdapter(max_retries=retryStrategy) + self.session.mount("https://", adapter) + + def requestCancel(self) -> None: + """Request cancellation of all active downloads.""" + log.debug("Cancellation requested") + self.cancelRequested = True + + # Cancel all active futures + with self.downloadLock: + for future in self.activeFutures: + if not future.done(): + future.cancel() + self.activeFutures.clear() + + def resetCancellation(self) -> None: + """Reset cancellation state for new download session.""" + with self.downloadLock: + self.cancelRequested = False + self.activeFutures.clear() + + def ensureModelsDirectory(self) -> str: + """ + Ensure the *models* directory exists (``../../models`` relative to *basePath*). + + :return: Absolute path of the *models* directory. + :raises OSError: When the directory cannot be created. 
+ """ + modelsDir = os.path.abspath(config.conf["automatedImageDescriptions"]["defaultModel"]) + + try: + Path(modelsDir).mkdir(parents=True, exist_ok=True) + except OSError as err: + raise OSError(f"Failed to create models directory {modelsDir}: {err}") from err + else: + log.debug(f"Models directory ensured: {modelsDir}") + return modelsDir + + def constructDownloadUrl( + self, + modelName: str, + filePath: str, + resolvePath: str = "/resolve/main", + ) -> str: + """ + Construct a full download URL for *Hugging Face‑style* repositories. + + :param modelName: Model repository name, e.g. ``Xenova/vit-gpt2-image-captioning``. + :param filePath: Path inside the repo. + :param resolvePath: The branch / ref path, default ``/resolve/main``. + :return: Complete download URL. + """ + remoteHost = self.remoteHost + if not remoteHost.startswith(("http://", "https://")): + remoteHost = f"https://{remoteHost}" + + base = remoteHost.rstrip("/") + model = modelName.strip("/") + ref = resolvePath.strip("/") + filePath = filePath.lstrip("/") + + return f"{base}/{model}/{ref}/{filePath}" + + def _getRemoteFileSize(self, url: str) -> int: + """ + Get remote file size using HEAD request with automatic redirect handling. + + :param url: Remote URL. + :return: File size in bytes, 0 if unable to determine. 
+ """ + if self.cancelRequested: + return 0 + + try: + # Use HEAD request with automatic redirect following + response = self.session.head(url, timeout=10, allow_redirects=True) + response.raise_for_status() + except Exception as e: + if not self.cancelRequested: + log.warning(f"Failed to get remote file size (HEAD) for {url}: {e}") + else: + contentLength = response.headers.get("Content-Length") + if contentLength: + return int(contentLength) + + try: + # If HEAD doesn't work, try GET with range header to get just 1 byte + response = self.session.get(url, headers={"Range": "bytes=0-0"}, timeout=10, allow_redirects=True) + except Exception as e: + if not self.cancelRequested: + log.warning(f"Failed to get remote file size (GET) for {url}: {e}") + else: + if response.status_code == 206: # Partial content + contentRange = response.headers.get("Content-Range", "") + if contentRange and "/" in contentRange: + return int(contentRange.split("/")[-1]) + + return 0 + + def _reportProgress( + self, + callback: ProgressCallback | None, + fileName: str, + downloaded: int, + total: int, + lastReported: int, + ) -> int: + """ + Report download progress if conditions are met. + + :param callback: Progress callback function. + :param fileName: Name of the file being downloaded. + :param downloaded: Bytes downloaded so far. + :param total: Total file size in bytes. + :param lastReported: Last reported download amount. + :return: New lastReported value. 
+ """ + if not callback or total == 0 or self.cancelRequested: + return lastReported + + percent = downloaded / total * 100 + + # Report progress every 1 MiB or 1% or when complete + if ( + downloaded - lastReported >= 1_048_576 # 1 MiB + or abs(percent - lastReported / total * 100) >= 1.0 + or downloaded == total + ): + callback(fileName, downloaded, total, percent) + return downloaded + + return lastReported + + def downloadSingleFile( + self, + url: str, + localPath: str, + progressCallback: ProgressCallback | None = None, + ) -> tuple[bool, str]: + """ + Download a single file with resume support and automatic redirect handling. + + :param url: Remote URL to download from. + :param localPath: Local file path to save the downloaded file. + :param progressCallback: Optional callback function for progress reporting. + :return: Tuple of (success_flag, status_message). + :raises OSError: When directory creation fails. + :raises requests.exceptions.RequestException: When network request fails. + :raises Exception: When unexpected errors occur during download. 
+ """ + if self.cancelRequested: + return False, "Download cancelled" + + threadId = threading.current_thread().ident or 0 + fileName = os.path.basename(localPath) + + # Create destination directory + success, message = self._createDestinationDirectory(localPath) + if not success: + return False, message + + # Get remote file size with redirect handling + remoteSize = self._getRemoteFileSize(url) + + if self.cancelRequested: + return False, "Download cancelled" + + # Check if file already exists and is complete + success, message = self._checkExistingFile( + localPath, + remoteSize, + fileName, + progressCallback, + threadId, + ) + if success is not None: + return success, message + + # Attempt download with retries + return self._downloadWithRetries(url, localPath, fileName, threadId, progressCallback) + + def _createDestinationDirectory(self, localPath: str) -> tuple[bool, str]: + """ + Create destination directory if it doesn't exist. + + :param localPath: Local file path to create directory for. + :return: Tuple of (success_flag, error_message). + :raises OSError: When directory creation fails due to permissions or disk space. + """ + try: + Path(os.path.dirname(localPath)).mkdir(parents=True, exist_ok=True) + return True, "" + except OSError as err: + return False, f"Failed to create directory {localPath}: {err}" + + def _checkExistingFile( + self, + localPath: str, + remoteSize: int, + fileName: str, + progressCallback: ProgressCallback | None, + threadId: int, + ) -> tuple[bool | None, str]: + """ + Check if file already exists and is complete. + + :param localPath: Local file path to check. + :param remoteSize: Size of remote file in bytes. + :param fileName: Base name of the file for progress reporting. + :param progressCallback: Optional callback function for progress reporting. + :param threadId: Current thread identifier for logging. + :return: Tuple of (completion_status, status_message). None status means download should continue. 
+ :raises OSError: When file operations fail. + """ + if not os.path.exists(localPath): + return None, "" + + localSize = os.path.getsize(localPath) + log.debug(f"localSize: {localSize}, remoteSize: {remoteSize}") + + if remoteSize > 0: + if localSize == remoteSize: + if progressCallback and not self.cancelRequested: + progressCallback(fileName, localSize, localSize, 100.0) + log.debug(f"File already complete: {localPath}") + return True, f"File already complete: {localPath}" + elif localSize > remoteSize: + # Local file is larger than remote, may be corrupted + log.warning(f"Local file larger than remote, removing: {localPath}") + try: + os.remove(localPath) + except OSError: + pass + else: + # Cannot get remote size, assume file exists if non-empty + if localSize > 0: + if progressCallback and not self.cancelRequested: + progressCallback(fileName, localSize, localSize, 100.0) + log.debug(f"File already exists (size unknown): {localPath}") + return True, f"File already exists: {localPath}" + + return None, "" + + def _downloadWithRetries( + self, + url: str, + localPath: str, + fileName: str, + threadId: int, + progressCallback: ProgressCallback | None, + ) -> tuple[bool, str]: + """ + Attempt download with retry logic and exponential backoff. + + :param url: Remote URL to download from. + :param localPath: Local file path to save the downloaded file. + :param fileName: Base name of the file for progress reporting. + :param threadId: Current thread identifier for logging. + :param progressCallback: Optional callback function for progress reporting. + :return: Tuple of (success_flag, status_message). + :raises requests.exceptions.HTTPError: When HTTP request fails. + :raises requests.exceptions.RequestException: When network request fails. + :raises Exception: When unexpected errors occur. 
+ """ + for attempt in range(self.maxRetries): + if self.cancelRequested: + return False, "Download cancelled" + + log.debug(f"Downloading (attempt {attempt + 1}/{self.maxRetries}): {url}") + + try: + success, message = self._performSingleDownload( + url, + localPath, + fileName, + threadId, + progressCallback, + ) + + except requests.exceptions.HTTPError as e: + message = self._handleHttpError(e, localPath, fileName, progressCallback, threadId) + if message.startswith("Download completed"): + return True, message + + except RequestException as e: + if self.cancelRequested: + return False, "Download cancelled" + message = f"Request error: {str(e)}" + + except Exception as e: + if self.cancelRequested: + return False, "Download cancelled" + message = f"Unexpected error: {str(e)}" + log.error(message) + + else: + if success: + return True, message + + if not self.cancelRequested: + log.debug(f"{message} – {url}") + if attempt < self.maxRetries - 1: + success = self._waitForRetry(attempt, threadId) + if not success: + return False, "Download cancelled" + else: + return False, message + + return False, "Maximum retries exceeded" + + def _performSingleDownload( + self, + url: str, + localPath: str, + fileName: str, + threadId: int, + progressCallback: ProgressCallback | None, + ) -> tuple[bool, str]: + """ + Perform a single download attempt with resume support. + + :param url: Remote URL to download from. + :param localPath: Local file path to save the downloaded file. + :param fileName: Base name of the file for progress reporting. + :param threadId: Current thread identifier for logging. + :param progressCallback: Optional callback function for progress reporting. + :return: Tuple of (success_flag, status_message). + :raises requests.exceptions.HTTPError: When HTTP request fails. + :raises requests.exceptions.RequestException: When network request fails. + :raises Exception: When file operations or unexpected errors occur. 
+ """ + # Check for existing partial file + resumePos = self._getResumePosition(localPath, threadId) + + # Get response with resume support + response = self._getDownloadResponse(url, resumePos, localPath, threadId) + + if self.cancelRequested: + return False, "Download cancelled" + + try: + # Determine total file size + total = self._calculateTotalSize(response, resumePos) + + if total > 0: + log.debug(f"Total file size: {total:,} bytes") + + # Download file content + success, message = self._downloadFileContent( + response, + localPath, + fileName, + resumePos, + total, + progressCallback, + ) + + if not success: + return False, message + + # Verify download integrity + return self._verifyDownloadIntegrity(localPath, fileName, total, progressCallback, threadId) + + finally: + response.close() + + def _getResumePosition(self, localPath: str, threadId: int) -> int: + """ + Get resume position for partial download. + + :param localPath: Local file path to check. + :param threadId: Current thread identifier for logging. + :return: Byte position to resume from. + :raises OSError: When file operations fail. + """ + resumePos = 0 + if os.path.exists(localPath): + resumePos = os.path.getsize(localPath) + log.debug(f"Resuming from byte {resumePos}") + return resumePos + + def _getDownloadResponse(self, url: str, resumePos: int, localPath: str, threadId: int) -> Response: + """ + Get download response with resume support and redirect handling. + + :param url: Remote URL to download from. + :param resumePos: Byte position to resume from. + :param localPath: Local file path for cleanup if needed. + :param threadId: Current thread identifier for logging. + :return: HTTP response object. + :raises requests.exceptions.HTTPError: When HTTP request fails. + :raises requests.exceptions.RequestException: When network request fails. + :raises Exception: When download is cancelled. 
+ """ + # Set up headers for resume + headers = {} + if resumePos > 0: + headers["Range"] = f"bytes={resumePos}-" + + # Make request with automatic redirect handling + response = self.session.get( + url, + headers=headers, + stream=True, + timeout=10, + allow_redirects=True, + ) + + # Check if resume is supported + if resumePos > 0 and response.status_code != 206: + log.debug("Server doesn't support resume, starting from beginning") + if os.path.exists(localPath): + try: + os.remove(localPath) + except OSError: + pass + + if self.cancelRequested: + response.close() + raise Exception("Download cancelled") + + # Make new request without range header + response.close() + response = self.session.get(url, stream=True, timeout=10, allow_redirects=True) + + response.raise_for_status() + return response + + def _calculateTotalSize(self, response: Response, resumePos: int) -> int: + """ + Calculate total file size from HTTP response headers. + + :param response: HTTP response object. + :param resumePos: Byte position resumed from. + :return: Total file size in bytes. + :raises ValueError: When Content-Range header is malformed. + """ + if response.status_code == 206: + # Partial content response + contentRange = response.headers.get("Content-Range", "") + if contentRange and "/" in contentRange: + return int(contentRange.split("/")[-1]) + else: + return int(response.headers.get("Content-Length", "0")) + resumePos + else: + return int(response.headers.get("Content-Length", "0")) + + def _downloadFileContent( + self, + response, + localPath: str, + fileName: str, + resumePos: int, + total: int, + progressCallback: ProgressCallback | None, + ) -> tuple[bool, str]: + """ + Download file content with progress reporting and cancellation support. + + :param response: HTTP response object to read from. + :param localPath: Local file path to write to. + :param fileName: Base name of the file for progress reporting. + :param resumePos: Byte position resumed from. 
+ :param total: Total file size in bytes. + :param progressCallback: Optional callback function for progress reporting. + :return: Tuple of (success_flag, error_message). + :raises OSError: When file write operations fail. + :raises Exception: When download is cancelled or unexpected errors occur. + """ + downloaded = resumePos + lastReported = downloaded + mode = "ab" if resumePos > 0 else "wb" + + try: + with open(localPath, mode) as fh: + for chunk in response.iter_content(chunk_size=CHUNK_SIZE): + if self.cancelRequested: + return False, "Download cancelled" + + if chunk: # filter out keep-alive chunks + fh.write(chunk) + downloaded += len(chunk) + + if total > 0: + lastReported = self._reportProgress( + progressCallback, + fileName, + downloaded, + total, + lastReported, + ) + except Exception as e: + return False, f"Failed to write file: {str(e)}" + + return True, "" + + def _verifyDownloadIntegrity( + self, + localPath: str, + fileName: str, + total: int, + progressCallback: ProgressCallback | None, + threadId: int, + ) -> tuple[bool, str]: + """ + Verify download integrity and report final progress. + + :param localPath: Local file path to verify. + :param fileName: Base name of the file for progress reporting. + :param total: Expected total file size in bytes. + :param progressCallback: Optional callback function for progress reporting. + :param threadId: Current thread identifier for logging. + :return: Tuple of (success_flag, status_message). + :raises OSError: When file operations fail. 
+ """ + if self.cancelRequested: + return False, "Download cancelled" + + actualSize = os.path.getsize(localPath) + + if actualSize == 0: + return False, "Downloaded file is empty" + + if total > 0 and actualSize != total: + return False, f"File incomplete: {actualSize}/{total} bytes downloaded" + + # Final progress callback + if progressCallback and not self.cancelRequested: + progressCallback(fileName, actualSize, max(total, actualSize), 100.0) + + log.debug(f"Successfully downloaded: {localPath}") + return True, "Download completed" + + def _handleHttpError( + self, + error: requests.exceptions.HTTPError, + localPath: str, + fileName: str, + progressCallback: ProgressCallback | None, + threadId: int, + ) -> str: + """ + Handle HTTP errors with special handling for range not satisfiable. + + :param error: HTTP error exception. + :param localPath: Local file path to check for completion. + :param fileName: Base name of the file for progress reporting. + :param progressCallback: Optional callback function for progress reporting. + :param threadId: Current thread identifier for logging. + :return: Error message or completion status. + :raises OSError: When file operations fail. + """ + if error.response is not None and error.response.status_code == 416: # Range Not Satisfiable + if os.path.exists(localPath): + actualSize = os.path.getsize(localPath) + if actualSize > 0: + log.debug(f"File appears to be complete: {localPath}") + if progressCallback and not self.cancelRequested: + progressCallback(fileName, actualSize, actualSize, 100.0) + return "Download completed" + + return f"HTTP {error.response.status_code if error.response else 'Error'}: {str(error)}" + + def _waitForRetry(self, attempt: int, threadId: int) -> bool: + """ + Wait for retry with exponential backoff and cancellation support. + + :param attempt: Current retry attempt number. + :param threadId: Current thread identifier for logging. + :return: True if wait completed, False if cancelled. 
+ """ + wait = BACKOFF_BASE**attempt + log.debug(f"Waiting {wait}s before retry...") + + for _ in range(wait): + if self.cancelRequested: + return False + time.sleep(1) + + return True + + def downloadModelsMultithreaded( + self, + modelsDir: str = WritePaths.modelsDir, + modelName: str = "Mozilla/distilvit", + filesToDownload: list[str] | None = None, + resolvePath: str = "/resolve/main", + progressCallback: ProgressCallback | None = None, + ) -> tuple[list[str], list[str]]: + """ + Download multiple model assets concurrently. + + :param modelsDir: Base *models* directory. + :param modelName: Repository name. + :param filesToDownload: Explicit file list; None uses common defaults. + :param resolvePath: Branch / ref path. + :param progressCallback: Optional progress callback. + :return: (successful_paths, failed_paths) tuple. + """ + if not self.remoteHost or not modelName: + raise ValueError("remoteHost and modelName cannot be empty") + + filesToDownload = filesToDownload or [ + "onnx/encoder_model_quantized.onnx", + "onnx/decoder_model_merged_quantized.onnx", + "config.json", + "vocab.json", + "preprocessor_config.json", + ] + + if not filesToDownload: + raise ValueError("filesToDownload cannot be empty") + + log.debug( + f"Starting download of {len(filesToDownload)} files for model: {modelName}\n" + f"Remote host: {self.remoteHost}\nMax workers: {self.maxWorkers}", + ) + + localModelDir = os.path.join(modelsDir, modelName) + successful: list[str] = [] + failed: list[str] = [] + + with ThreadPoolExecutor(max_workers=self.maxWorkers) as executor: + futures = [] + + for path in filesToDownload: + if self.cancelRequested: + break + + future = executor.submit( + self.downloadSingleFile, + self.constructDownloadUrl(modelName, path, resolvePath), + os.path.join(localModelDir, path), + progressCallback, + ) + futures.append((future, path)) + + # Track active futures for cancellation + with self.downloadLock: + self.activeFutures.add(future) + + # Process completed 
futures + for future, filePath in futures: + if self.cancelRequested: + # Cancel remaining futures but don't wait for them + with self.downloadLock: + for f, _ in futures: + if not f.done(): + f.cancel() + break + + # Remove from active futures tracking + with self.downloadLock: + self.activeFutures.discard(future) + + try: + ok, msg = future.result() + if ok: + successful.append(filePath) + log.debug(f"successful {filePath=}") + else: + failed.append(filePath) + log.debug(f"failed: {filePath} - {msg}") + except Exception as err: + failed.append(filePath) + log.debug(f"failed: {filePath} – {err}") + + # Summary + if not self.cancelRequested: + log.debug(f"Total: {len(filesToDownload)}") + log.debug(f"Successful: {len(successful)}") + log.debug(f"Failed: {len(failed)}") + log.debug(f"\nLocal model directory: {localModelDir}") + else: + log.debug("Download cancelled by user") + + return successful, failed + + def __del__(self): + """Clean up the session when the downloader is destroyed.""" + if hasattr(self, "session"): + self.session.close() diff --git a/source/config/__init__.py b/source/config/__init__.py index 93c66f28c61..0fbd3cdb98e 100644 --- a/source/config/__init__.py +++ b/source/config/__init__.py @@ -416,6 +416,7 @@ class ConfigManager(object): "development", "addonStore", "remote", + "automatedImageDescriptions", "math", "screenCurtain", } diff --git a/source/config/configSpec.py b/source/config/configSpec.py index da4c436b58d..9790082625a 100644 --- a/source/config/configSpec.py +++ b/source/config/configSpec.py @@ -97,7 +97,7 @@ reportLiveRegions = featureFlag(optionsEnum="BoolFlag", behaviorOfDefault="enabled") fontFormattingDisplay = featureFlag(optionsEnum="FontFormattingBrailleModeFlag", behaviorOfDefault="LIBLOUIS") [[auto]] - excludedDisplays = string_list(default=list("dotPad")) + excludedDisplays = string_list(default=list("dotPad")) # Braille display driver settings [[__many__]] @@ -373,23 +373,23 @@ [math] [[speech]] # LearningDisability, 
Blindness, LowVision - impairment = string(default="Blindness") + impairment = string(default="Blindness") # any known language code and sub-code -- could be en-uk, etc - language = string(default="Auto") + language = string(default="Auto") # Any known speech style (falls back to ClearSpeak) - speechStyle = string(default="ClearSpeak") + speechStyle = string(default="ClearSpeak") # Terse, Medium, Verbose - verbosity = string(default="Medium") + verbosity = string(default="Medium") # Change from text speech rate (%) - mathRate = integer(default=100) + mathRate = integer(default=100) # Change from normal pause length (%) - pauseFactor = integer(default=100) + pauseFactor = integer(default=100) # make a sound when starting/ending math speech -- None, Beep - speechSound = string(default="None") + speechSound = string(default="None") # NOTE: not currently working in MathCAT - subjectArea = string(default="General") + subjectArea = string(default="General") # SpellOut (H 2 0), AsCompound (Water) -- not implemented, Off (H sub 2 O) - chemistry = string(default="SpellOut") + chemistry = string(default="SpellOut") # Verbose, Brief, SuperBrief mathSpeak = string(default="Verbose") @@ -529,6 +529,10 @@ # Auto, '.', ',', Custom decimalSeparator = string(default="Auto") +[automatedImageDescriptions] + enable = boolean(default=false) + defaultModel = string(default="Mozilla/distilvit") + [screenCurtain] enabled = boolean(default=false) warnOnLoad = boolean(default=true) diff --git a/source/core.py b/source/core.py index a6ff433c74f..113b88ac5bd 100644 --- a/source/core.py +++ b/source/core.py @@ -555,6 +555,7 @@ def _handleNVDAModuleCleanupBeforeGUIExit(): import globalPluginHandler import watchdog import _remoteClient + import _localCaptioner try: import updateCheck @@ -573,6 +574,8 @@ def _handleNVDAModuleCleanupBeforeGUIExit(): # Terminating remoteClient causes it to clean up its menus, so do it here while they still exist _terminate(_remoteClient) + 
_terminate(_localCaptioner) + def _initializeObjectCaches(): """ @@ -916,6 +919,10 @@ def main(): _remoteClient.initialize() + import _localCaptioner + + _localCaptioner.initialize() + if globalVars.appArgs.install or globalVars.appArgs.installSilent: import gui.installerGui diff --git a/source/globalCommands.py b/source/globalCommands.py index 92d009eb518..25984a1ba35 100755 --- a/source/globalCommands.py +++ b/source/globalCommands.py @@ -72,6 +72,7 @@ import synthDriverHandler from utils.displayString import DisplayStringEnum import _remoteClient +import _localCaptioner #: Script category for text review commands. # Translators: The name of a category of NVDA commands. @@ -124,6 +125,9 @@ #: Script category for Remote Access commands. # Translators: The name of a category of NVDA commands. SCRCAT_REMOTE = pgettext("remote", "Remote Access") +#: Script category for image description commands. +# Translators: The name of a category of NVDA commands. +SCRCAT_IMAGE_DESC = pgettext("imageDesc", "Image Descriptions") # Translators: Reported when there are no settings to configure in synth settings ring # (example: when there is no setting for language). @@ -3517,6 +3521,15 @@ def script_activateDocumentFormattingDialog(self, gesture): def script_activateRemoteAccessSettings(self, gesture: "inputCore.InputGesture"): wx.CallAfter(gui.mainFrame.onRemoteAccessSettingsCommand, None) + @script( + # Translators: Input help mode message for go to local captioner settings command. + description=pgettext("imageDesc", "Shows the AI image descriptions settings"), + category=SCRCAT_CONFIG, + ) + @gui.blockAction.when(gui.blockAction.Context.MODAL_DIALOG_OPEN) + def script_activateLocalCaptionerSettings(self, gesture: "inputCore.InputGesture"): + wx.CallAfter(gui.mainFrame.onLocalCaptionerSettingsCommand, None) + @script( # Translators: Input help mode message for go to Add-on Store settings command. 
description=_("Shows NVDA's Add-on Store settings"), @@ -5143,6 +5156,30 @@ def script_repeatLastSpokenInformation(self, gesture: "inputCore.InputGesture") title = _("Last spoken information") ui.browseableMessage(lastSpeechText, title, copyButton=True, closeButton=True) + @script( + description=pgettext( + "imageDesc", + # Translators: Description for the image caption script + "Get an AI-generated image description of the navigator object.", + ), + category=SCRCAT_IMAGE_DESC, + gesture="kb:NVDA+g", + ) + @gui.blockAction.when(gui.blockAction.Context.SCREEN_CURTAIN) + def script_runCaption(self, gesture: "inputCore.InputGesture"): + _localCaptioner._localCaptioner.runCaption(gesture) + + @script( + description=pgettext( + "imageDesc", + # Translators: Description for the toggle image captioning script + "Load or unload the image captioner", + ), + category=SCRCAT_IMAGE_DESC, + ) + def script_toggleImageCaptioning(self, gesture: "inputCore.InputGesture"): + _localCaptioner._localCaptioner.toggleImageCaptioning(gesture) + #: The single global commands instance. 
#: @type: L{GlobalCommands} diff --git a/source/gui/__init__.py b/source/gui/__init__.py index 023e787e177..5391ccce5d3 100644 --- a/source/gui/__init__.py +++ b/source/gui/__init__.py @@ -56,6 +56,7 @@ GeneralSettingsPanel, InputCompositionPanel, KeyboardSettingsPanel, + LocalCaptionerSettingsPanel, MouseSettingsPanel, MultiCategorySettingsDialog, NVDASettingsDialog, @@ -387,6 +388,10 @@ def onUwpOcrCommand(self, evt): def onRemoteAccessSettingsCommand(self, evt): self.popupSettingsDialog(NVDASettingsDialog, RemoteSettingsPanel) + @blockAction.when(blockAction.Context.SECURE_MODE) + def onLocalCaptionerSettingsCommand(self, evt): + self.popupSettingsDialog(NVDASettingsDialog, LocalCaptionerSettingsPanel) + @blockAction.when(blockAction.Context.SECURE_MODE) def onAdvancedSettingsCommand(self, evt: wx.CommandEvent): self.popupSettingsDialog(NVDASettingsDialog, AdvancedPanel) diff --git a/source/gui/_localCaptioner/__init__.py b/source/gui/_localCaptioner/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/source/gui/_localCaptioner/messageDialogs.py b/source/gui/_localCaptioner/messageDialogs.py new file mode 100644 index 00000000000..27fa32d7b74 --- /dev/null +++ b/source/gui/_localCaptioner/messageDialogs.py @@ -0,0 +1,198 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later, as modified by the NVDA license. 
+# For full terms and any additional permissions, see the NVDA license file: https://github.com/nvaccess/nvda/blob/master/copying.txt + +from gui.message import MessageDialog, DefaultButton, ReturnCode, DialogType +import gui +from _localCaptioner.modelDownloader import ModelDownloader, ProgressCallback +import threading +from threading import Thread +import wx +import ui +import _localCaptioner + + +class ImageDescDownloader: + _downloadThread: Thread | None = None + isOpening: bool = False + + def __init__(self): + self.downloadDict: dict[str, tuple[int, int]] = {} + self.modelDownloader: ModelDownloader | None = None + self._shouldCancel = False + self._progressDialog: wx.ProgressDialog | None = None + self.filesToDownload = [ + "onnx/encoder_model_quantized.onnx", + "onnx/decoder_model_merged_quantized.onnx", + "config.json", + "vocab.json", + "preprocessor_config.json", + ] + + def onDownload(self, progressCallback: ProgressCallback) -> None: + self.modelDownloader = ModelDownloader() + (success, fail) = self.modelDownloader.downloadModelsMultithreaded( + filesToDownload=self.filesToDownload, + progressCallback=progressCallback, + ) + if len(fail) == 0: + wx.CallAfter(self.openSuccessDialog) + else: + wx.CallAfter(self.openFailDialog) + + def openSuccessDialog(self) -> None: + confirmationButton = (DefaultButton.OK.value._replace(defaultFocus=True, fallbackAction=True),) + self._stopped() + + dialog = MessageDialog( + parent=None, + # Translators: title of dialog when download successfully + title=pgettext("imageDesc", "Download successful"), + message=pgettext( + "imageDesc", + # Translators: label of dialog when downloading image captioning + "Image captioning installed successfully.", + ), + dialogType=DialogType.STANDARD, + buttons=confirmationButton, + ) + + if dialog.ShowModal() == ReturnCode.OK: + # load image desc after successful download + if not _localCaptioner.isModelLoaded(): + _localCaptioner.toggleImageCaptioning() + + def openFailDialog(self) 
-> None: + if self._shouldCancel: + return + + confirmationButtons = ( + DefaultButton.YES.value._replace(defaultFocus=True, fallbackAction=False), + DefaultButton.NO.value._replace(defaultFocus=False, fallbackAction=True), + ) + + dialog = MessageDialog( + parent=None, + # Translators: title of dialog when fail to download + title=pgettext("imageDesc", "Download failed"), + message=pgettext( + "imageDesc", + # Translators: label of dialog when fail to download image captioning + "Image captioning download failed. Would you like to retry?", + ), + dialogType=DialogType.WARNING, + buttons=confirmationButtons, + ) + + if dialog.ShowModal() == ReturnCode.YES: + self.doDownload() + else: + self._stopped() + + def openDownloadDialog(self) -> None: + if ImageDescDownloader._downloadThread is not None and ImageDescDownloader._downloadThread.is_alive(): + # Translators: message when image captioning is still downloading + ui.message(pgettext("imageDesc", "image captioning is still downloading, please wait...")) + return + if ImageDescDownloader.isOpening: + return + + confirmationButtons = ( + DefaultButton.YES.value._replace(defaultFocus=True, fallbackAction=False), + DefaultButton.NO.value._replace(defaultFocus=False, fallbackAction=True), + ) + + dialog = MessageDialog( + parent=None, + # Translators: title of dialog when downloading Image captioning + title=pgettext("imageDesc", "Confirm download"), + message=pgettext( + "imageDesc", + # Translators: label of dialog when downloading image captioning + "Image captioning not installed. Would you like to install (178 MB)?", + ), + dialogType=DialogType.WARNING, + buttons=confirmationButtons, + ) + ImageDescDownloader.isOpening = True + + if dialog.ShowModal() == ReturnCode.YES: + self._progressDialog = wx.ProgressDialog( + # Translators: The title of the dialog displayed while downloading image descriptioner. 
+ pgettext("imageDesc", "Downloading Image Descriptioner"), + # Translators: The progress message indicating that a connection is being established. + pgettext("imageDesc", "Connecting"), + style=wx.PD_CAN_ABORT | wx.PD_ELAPSED_TIME | wx.PD_REMAINING_TIME | wx.PD_AUTO_HIDE, + parent=gui.mainFrame, + ) + self.doDownload() + else: + ImageDescDownloader.isOpening = False + + def doDownload(self): + def progressCallback( + fileName: str, + downloadedBytes: int, + totalBytes: int, + _percentage: float, + ) -> None: + """Callback function to capture progress data.""" + self.downloadDict[fileName] = (downloadedBytes, totalBytes) + downloadedSum = sum(d for d, _ in self.downloadDict.values()) + totalSum = sum(t for _, t in self.downloadDict.values()) + ratio = downloadedSum / totalSum if totalSum > 0 else 0.0 + totalProgress = int(ratio * 100) + # update progress when downloading all files to prevent premature stop + if len(self.downloadDict) == len(self.filesToDownload): + # Translators: The progress message indicating that a download is in progress. 
+ cont, skip = self._progressDialog.Update(totalProgress, pgettext("imageDesc", "downloading")) + if not cont: + self._shouldCancel = True + self._stopped() + + ImageDescDownloader._downloadThread = threading.Thread( + target=self.onDownload, + name="ModelDownloadMainThread", + daemon=False, + args=(progressCallback,), + ) + ImageDescDownloader._downloadThread.start() + + def _stopped(self): + self.modelDownloader.requestCancel() + ImageDescDownloader._downloadThread = None + self._progressDialog.Hide() + self._progressDialog.Destroy() + self._progressDialog = None + ImageDescDownloader.isOpening = False + + +def openEnableOnceDialog() -> None: + confirmationButtons = ( + DefaultButton.YES.value._replace(defaultFocus=True, fallbackAction=False), + DefaultButton.NO.value._replace(defaultFocus=False, fallbackAction=True), + ) + + dialog = MessageDialog( + parent=None, + # Translators: title of dialog when enable image desc + title=pgettext("imageDesc", "Enable AI image descriptions"), + message=pgettext( + "imageDesc", + # Translators: label of dialog when enable image desc + "AI image descriptions are currently disabled." + "\n\n" + "Warning: AI image descriptions are experimental. " + "Do not use this feature in circumstances where inaccurate descriptions could cause harm." 
+ "\n\n" + "Would you like to temporarily enable AI image descriptions now?", + ), + dialogType=DialogType.STANDARD, + buttons=confirmationButtons, + ) + + if dialog.ShowModal() == ReturnCode.YES: + # load image desc in this session + if not _localCaptioner.isModelLoaded(): + _localCaptioner.toggleImageCaptioning() diff --git a/source/gui/blockAction.py b/source/gui/blockAction.py index 008ade69abf..21fbe4ab8f3 100644 --- a/source/gui/blockAction.py +++ b/source/gui/blockAction.py @@ -39,6 +39,14 @@ def _isRemoteAccessDisabled() -> bool: return not remoteRunning() +def _isScreenCurtainEnabled() -> bool: + """Whether screen curtain functionality is **enabled**.""" + # Import late to avoid circular import + from screenCurtain import screenCurtain + + return screenCurtain is not None and screenCurtain.enabled + + @dataclass class _Context: blockActionIf: Callable[[], bool] @@ -86,6 +94,11 @@ class Context(_Context, Enum): # Translators: Reported when an action cannot be performed because Remote Access functionality is disabled. pgettext("remote", "Action unavailable when Remote Access is disabled"), ) + SCREEN_CURTAIN = ( + lambda: _isScreenCurtainEnabled(), + # Translators: Reported when an action cannot be performed because screen curtain is enabled. + _("Action unavailable while screen curtain is enabled"), + ) def when(*contexts: Context): diff --git a/source/gui/settingsDialogs.py b/source/gui/settingsDialogs.py index 2370134380a..507504d9ef5 100644 --- a/source/gui/settingsDialogs.py +++ b/source/gui/settingsDialogs.py @@ -4027,6 +4027,53 @@ def onSave(self): _remoteClient.terminate() +class LocalCaptionerSettingsPanel(SettingsPanel): + """Settings panel for Local captioner configuration.""" + + # Translators: This is the label for the local captioner settings panel. 
+ title = pgettext("imageDesc", "AI Image Descriptions") + helpId = "LocalCaptionerSettings" + panelDescription = pgettext( + "imageDesc", + # Translators: This is a label appearing on the AI Image Descriptions settings panel. + "Warning: AI image descriptions are experimental. " + "Do not use this feature in circumstances where inaccurate descriptions could cause harm.", + ) + + def makeSettings(self, settingsSizer: wx.BoxSizer): + """Create the settings controls for the panel. + + :param settingsSizer: The sizer to add settings controls to. + """ + + sHelper = guiHelper.BoxSizerHelper(self, sizer=settingsSizer) + + self.windowText = sHelper.addItem( + wx.StaticText(self, label=self.panelDescription), + ) + self.windowText.Wrap(self.scaleSize(PANEL_DESCRIPTION_WIDTH)) + + self.enable = sHelper.addItem( + # Translators: A configuration in settings dialog. + wx.CheckBox(self, label=pgettext("imageDesc", "Enable image captioner")), + ) + self.enable.SetValue(config.conf["automatedImageDescriptions"]["enable"]) + self.bindHelpEvent("LocalCaptionToggle", self.enable) + + def onSave(self) -> None: + """Save the configuration settings.""" + enabled = self.enable.GetValue() + oldEnabled = config.conf["automatedImageDescriptions"]["enable"] + + if enabled != oldEnabled: + import _localCaptioner + + if enabled != _localCaptioner.isModelLoaded(): + _localCaptioner.toggleImageCaptioning() + + config.conf["automatedImageDescriptions"]["enable"] = enabled + + class TouchInteractionPanel(SettingsPanel): # Translators: This is the label for the touch interaction settings panel. title = _("Touch Interaction") @@ -6089,6 +6136,7 @@ class NVDASettingsDialog(MultiCategorySettingsDialog): DocumentNavigationPanel, MathSettingsPanel, RemoteSettingsPanel, + LocalCaptionerSettingsPanel, ] # In secure mode, add-on update is disabled, so AddonStorePanel should not appear since it only contains # add-on update related controls. 
@@ -6117,6 +6165,7 @@ def _doOnCategoryChange(self): or isinstance(self.currentCategory, GeneralSettingsPanel) or isinstance(self.currentCategory, AddonStorePanel) or isinstance(self.currentCategory, RemoteSettingsPanel) + or isinstance(self.currentCategory, LocalCaptionerSettingsPanel) or isinstance(self.currentCategory, MathSettingsPanel) or isinstance(self.currentCategory, PrivacyAndSecuritySettingsPanel) ): diff --git a/source/setup.py b/source/setup.py index d7cb15b0029..bba5f0dbcc3 100755 --- a/source/setup.py +++ b/source/setup.py @@ -213,8 +213,6 @@ def _genManifestTemplate(shouldHaveUIAccess: bool) -> tuple[int, int, bytes]: # winxptheme is optionally used by wx.lib.agw.aui. # We don't need this. "winxptheme", - # numpy is an optional dependency of comtypes but we don't require it. - "numpy", # multiprocessing isn't going to work in a frozen environment "multiprocessing", "concurrent.futures.process", @@ -246,6 +244,8 @@ def _genManifestTemplate(shouldHaveUIAccess: bool) -> tuple[int, int, bytes]: "mdx_truly_sane_lists", "mdx_gh_links", "pymdownx", + # Required for local image captioning + "numpy", ], "includes": [ "nvdaBuiltin", @@ -253,6 +253,9 @@ def _genManifestTemplate(shouldHaveUIAccess: bool) -> tuple[int, int, bytes]: "bisect", # robotremoteserver (for system tests) depends on xmlrpc.server "xmlrpc.server", + # required for import numpy without error + "numpy._core._exceptions", + "numpy._core._multiarray_umath", ], }, data_files=[ diff --git a/tests/system/libraries/SystemTestSpy/configManager.py b/tests/system/libraries/SystemTestSpy/configManager.py index fda1830f67d..ae9286af9ea 100644 --- a/tests/system/libraries/SystemTestSpy/configManager.py +++ b/tests/system/libraries/SystemTestSpy/configManager.py @@ -105,6 +105,9 @@ def setupProfile( _pJoin(repoRoot, "tests", "system", "nvdaSettingsFiles", settingsFileName), _pJoin(stagingDir, "nvdaProfile", "nvda.ini"), ) + if _shouldGenerateMockModel(_pJoin(stagingDir, "nvdaProfile", "nvda.ini")): + 
_configModels(_pJoin(stagingDir, "nvdaProfile", "models", "mock", "vit-gpt2-image-captioning"))
+
 	if gesturesFileName is not None:
 		opSys.copy_file(
 			# Despite duplication, specify full paths for clarity.
@@ -128,3 +131,27 @@ def teardownProfile(stagingDir: str):
 		_pJoin(stagingDir, "nvdaProfile"),
 		recursive=True,
 	)
+
+
+def _configModels(modelsDirectory: str) -> None:
+	from .mockModels import MockVisionEncoderDecoderGenerator
+
+	generator = MockVisionEncoderDecoderGenerator(randomSeed=8)
+	generator.generateAllFiles(modelsDirectory)
+
+
+def _shouldGenerateMockModel(iniPath: str) -> bool:
+	# Read original lines
+	with open(iniPath, "r", encoding="utf-8") as f:
+		lines = f.readlines()
+
+	for line in lines:
+		# Detect section headers
+		stripLine = line.strip()
+		if stripLine.startswith("[") and stripLine.endswith("]"):
+			hasCaptionSection = stripLine.lower() == "[automatedimagedescriptions]"
+			if hasCaptionSection:
+				return True
+			else:
+				continue
+	# No [automatedImageDescriptions] section found: return an explicit bool
+	# instead of falling through and implicitly returning None.
+	return False
diff --git a/tests/system/libraries/SystemTestSpy/mockModels.py b/tests/system/libraries/SystemTestSpy/mockModels.py
new file mode 100644
index 00000000000..896b3901fb2
--- /dev/null
+++ b/tests/system/libraries/SystemTestSpy/mockModels.py
@@ -0,0 +1,793 @@
+# A part of NonVisual Desktop Access (NVDA)
+# Copyright (C) 2025 NV Access Limited, Tianze
+# This file may be used under the terms of the GNU General Public License, version 2 or later.
+# For more details see: https://www.gnu.org/licenses/gpl-2.0.html
+"""
+Mock Vision-Encoder-Decoder Model Generator
+
+This module provides a class to generate mock ONNX models and configuration files
+for a Vision-Encoder-Decoder model (ViT-GPT2 style) used for image captioning.
+The generated files can be used for testing and development purposes.
+""" + +import os +import json +from pathlib import Path +from typing import Any + +import numpy as np +import onnx +from onnx import helper, TensorProto, numpy_helper + + +class MockVisionEncoderDecoderGenerator: + """ + A class to generate mock ONNX models and configuration files for a + Vision-Encoder-Decoder model architecture. + + This generator creates: + - onnx/encoder_model_quantized.onnx: Vision Transformer encoder + - onnx/decoder_model_merged_quantized.onnx: GPT-2 style decoder + - config.json: Model configuration + - vocab.json: Vocabulary mapping + """ + + def __init__(self, randomSeed: int = 8): + """ + Initialize the mock model generator. + + :param randomSeed (int): Random seed for reproducible weight generation.Defaults to 8. + """ + self.randomSeed = randomSeed + self._setRandomSeed() + + # Model hyperparameters + self.vocab_size = 100 + self.hidden_size = 64 + self.n_layers = 12 + self.image_size = 224 + self.patch_size = 16 + self.num_channels = 3 + + # Derived parameters + self.num_patches = (self.image_size // self.patch_size) ** 2 + + def _setRandomSeed(self) -> None: + """Set random seed for reproducible results.""" + np.random.seed(self.randomSeed) + + def generateAllFiles(self, outputDir: str) -> None: + """ + Generate all mock model files in the specified directory. + + :param outputDir (str): Target directory to create the model files. Will create the directory if it doesn't exist. 
+ """ + outputPath = Path(outputDir) + outputPath.mkdir(parents=True, exist_ok=True) + + # Create onnx subdirectory + onnxDir = outputPath / "onnx" + onnxDir.mkdir(exist_ok=True) + + # Generate all components + self._generateEncoderModel(os.path.join(onnxDir, "encoder_model_quantized.onnx")) + self._generateDecoderModel(os.path.join(onnxDir, "decoder_model_merged_quantized.onnx")) + self._generateConfigFile(os.path.join(outputPath, "config.json")) + self._generateVocabFile(os.path.join(outputPath, "vocab.json")) + + def _generateEncoderModel(self, outputPath: Path) -> None: + """ + Generate the Vision Transformer encoder ONNX model. + + This creates a simplified ViT encoder that performs patch embedding + using convolution followed by reshaping operations. + + :param outputPath (Path): Output path for the encoder ONNX file. + """ + # Define input and output specifications + pixelValues = helper.make_tensor_value_info( + "pixelValues", + TensorProto.FLOAT, + ["batch", self.num_channels, self.image_size, self.image_size], + ) + + patchEmbeds = helper.make_tensor_value_info( + "patchEmbeds", + TensorProto.FLOAT, + ["batch", self.num_patches, self.hidden_size], + ) + + # Generate random but reproducible weights for patch embedding + convWeights = np.random.randn( + self.hidden_size, + self.num_channels, + self.patch_size, + self.patch_size, + ).astype(np.float32) + + convBias = np.zeros(self.hidden_size, dtype=np.float32) + + # Create initializers + weightInit = numpy_helper.from_array(convWeights, "convWeights") + biasInit = numpy_helper.from_array(convBias, "convBias") + + # Shape constant for reshaping + targetShape = np.array([0, self.num_patches, self.hidden_size], dtype=np.int64) + shapeInit = numpy_helper.from_array(targetShape, "targetShape") + + # Define computation nodes + nodes = [ + # Patch embedding using convolution + helper.make_node( + "Conv", + inputs=["pixelValues", "convWeights", "convBias"], + outputs=["conv_output"], + 
kernel_shape=[self.patch_size, self.patch_size], + strides=[self.patch_size, self.patch_size], + ), + # Transpose to get correct dimension order + # From [batch, hidden_size, patch_h, patch_w] to [batch, patch_h, patch_w, hidden_size] + helper.make_node( + "Transpose", + inputs=["conv_output"], + outputs=["transposed_output"], + perm=[0, 2, 3, 1], + ), + # Reshape to flatten patches + # From [batch, patch_h, patch_w, hidden_size] to [batch, num_patches, hidden_size] + helper.make_node( + "Reshape", + inputs=["transposed_output", "targetShape"], + outputs=["patchEmbeds"], + ), + ] + + # Create and save the model + graph = helper.make_graph( + nodes=nodes, + name="VisionTransformerEncoder", + inputs=[pixelValues], + outputs=[patchEmbeds], + initializer=[weightInit, biasInit, shapeInit], + ) + + model = helper.make_model(graph, producer_name="mock-vit-encoder") + model.opset_import[0].version = 13 + model.ir_version = 10 + + onnx.save(model, str(outputPath)) + + def _generateDecoderModel(self, outputPath: Path) -> None: + """ + Generate the GPT-2 style decoder ONNX model. + + This creates a simplified decoder that accepts multiple inputs including + token IDs, encoder hidden states, cache flags, and past key-value pairs. + + :param outputPath (Path): Output path for the decoder ONNX file. 
+ """ + # Generate fixed random weights for reproducibility + embeddingWeights = np.random.randn( + self.vocab_size, + self.hidden_size, + ).astype(np.float32) + + projectionWeights = np.random.randn( + self.hidden_size, + self.vocab_size, + ).astype(np.float32) + + # Create weight initializers + embInit = numpy_helper.from_array(embeddingWeights, "embeddingWeights") + projInit = numpy_helper.from_array(projectionWeights, "projectionWeights") + + # Define all input specifications + inputs = self._createDecoderInputs() + + # Define output specification + outputs = [ + helper.make_tensor_value_info( + "logits", + TensorProto.FLOAT, + ["batch", "seq", self.vocab_size], + ), + ] + + # Create computation nodes + nodes = self._createDecoderNodes() + + # Create shape and scaling constants + shapeConstants = self._createDecoderConstants() + + # Combine all initializers + initializers = [embInit, projInit] + shapeConstants + + # Create and save the model + graph = helper.make_graph( + nodes=nodes, + name="GPT2DecoderWithCache", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + + model = helper.make_model(graph, producer_name="mock-gpt2-decoder") + model.opset_import[0].version = 13 + model.ir_version = 10 + + onnx.save(model, str(outputPath)) + + def _createDecoderInputs(self) -> list: + """ + Create input specifications for the decoder model. + + :return: list: List of tensor value info objects for all decoder inputs. 
+ """ + inputs = [] + + # Primary inputs + inputs.extend( + [ + helper.make_tensor_value_info( + "input_ids", + TensorProto.INT64, + ["batch", "seq"], + ), + helper.make_tensor_value_info( + "encoder_hidden_states", + TensorProto.FLOAT, + ["batch", "enc_seq_len", self.hidden_size], + ), + helper.make_tensor_value_info( + "use_cache_branch", + TensorProto.BOOL, + ["batch"], + ), + ], + ) + + # Past key-value cache inputs for each layer + for layerIdx in range(self.n_layers): + inputs.extend( + [ + helper.make_tensor_value_info( + f"past_key_values.{layerIdx}.key", + TensorProto.FLOAT, + ["batch", "num_heads", "past_seq_len", self.hidden_size], + ), + helper.make_tensor_value_info( + f"past_key_values.{layerIdx}.value", + TensorProto.FLOAT, + ["batch", "num_heads", "past_seq_len", self.hidden_size], + ), + ], + ) + + return inputs + + def _createDecoderNodes(self) -> list: + """ + Create computation nodes for the decoder model. + + :return: list: List of ONNX nodes defining the decoder computation. 
+ """ + nodes = [] + + # Token embedding lookup + nodes.append( + helper.make_node( + "Gather", + inputs=["embeddingWeights", "input_ids"], + outputs=["token_embeddings"], + axis=0, + ), + ) + + # Process encoder hidden states + nodes.extend(self._createEncoderProcessingNodes()) + + # Process cache branch flag + nodes.extend(self._createCacheProcessingNodes()) + + # Process past key-value pairs + cacheFeatures = self._createCacheFeatureNodes(nodes) + + # Combine all auxiliary features + nodes.extend(self._createFeatureCombinationNodes(cacheFeatures)) + + # Apply main computation pipeline + nodes.extend(self._createMainComputationNodes()) + + return nodes + + def _createEncoderProcessingNodes(self) -> list: + """Create nodes to process encoder hidden states.""" + return [ + # Global average pooling over encoder states + helper.make_node( + "ReduceMean", + inputs=["encoder_hidden_states"], + outputs=["encoder_pooled"], + axes=[1, 2], # Pool over sequence length and hidden dimensions + ), + # Reshape for broadcasting + helper.make_node( + "Reshape", + inputs=["encoder_pooled", "shapeBatch1"], + outputs=["encoder_feature"], + ), + ] + + def _createCacheProcessingNodes(self) -> list: + """Create nodes to process the cache branch flag.""" + return [ + # Convert boolean to float + helper.make_node( + "Cast", + inputs=["use_cache_branch"], + outputs=["cache_flag_float"], + to=TensorProto.FLOAT, + ), + # Reshape for broadcasting + helper.make_node( + "Reshape", + inputs=["cache_flag_float", "shapeBatch1"], + outputs=["cache_flag_feature"], + ), + ] + + def _createCacheFeatureNodes(self, nodes: list) -> list: + """ + Create nodes to process past key-value cache inputs. + + :param nodes (list): List to append new nodes to. + :return: list: Names of cache feature tensors. 
+ """ + cacheFeatures = [] + + for layerIdx in range(self.n_layers): + # Process key cache + nodes.extend( + [ + helper.make_node( + "ReduceMean", + inputs=[f"past_key_values.{layerIdx}.key"], + outputs=[f"cache_key_{layerIdx}_pooled"], + axes=[1, 2, 3], # Global pooling, keep only batch dimension + ), + helper.make_node( + "Reshape", + inputs=[f"cache_key_{layerIdx}_pooled", "shapeBatch1"], + outputs=[f"cache_key_{layerIdx}_feature"], + ), + ], + ) + + # Process value cache + nodes.extend( + [ + helper.make_node( + "ReduceMean", + inputs=[f"past_key_values.{layerIdx}.value"], + outputs=[f"cache_value_{layerIdx}_pooled"], + axes=[1, 2, 3], + ), + helper.make_node( + "Reshape", + inputs=[f"cache_value_{layerIdx}_pooled", "shapeBatch1"], + outputs=[f"cache_value_{layerIdx}_feature"], + ), + ], + ) + + cacheFeatures.extend( + [ + f"cache_key_{layerIdx}_feature", + f"cache_value_{layerIdx}_feature", + ], + ) + + return cacheFeatures + + def _createFeatureCombinationNodes(self, cacheFeatures: list) -> list: + """ + Create nodes to combine all auxiliary features. + + :param cacheFeatures (list): List of cache feature tensor names. + :return: list: Nodes for feature combination. 
+ """ + nodes = [] + allFeatures = ["encoder_feature", "cache_flag_feature"] + cacheFeatures + + # Sequentially add all features together + currentSum = allFeatures[0] + for i, feature in enumerate(allFeatures[1:], 1): + nodes.append( + helper.make_node( + "Add", + inputs=[currentSum, feature], + outputs=[f"combined_features_{i}"], + ), + ) + currentSum = f"combined_features_{i}" + + return nodes + + def _createMainComputationNodes(self) -> list: + """Create the main computation pipeline nodes.""" + finalCombined = f"combined_features_{self.n_layers * 2 + 1}" + + return [ + # Flatten token embeddings + helper.make_node( + "Reshape", + inputs=["token_embeddings", "shape2d"], + outputs=["embeddings_flat"], + ), + # Scale embeddings + helper.make_node( + "Mul", + inputs=["embeddings_flat", "featureScale"], + outputs=["scaled_embeddings"], + ), + # Add auxiliary features (broadcasting) + helper.make_node( + "Add", + inputs=["scaled_embeddings", finalCombined], + outputs=["final_features"], + ), + # Project to vocabulary space + helper.make_node( + "MatMul", + inputs=["final_features", "projectionWeights"], + outputs=["logits_flat"], + ), + # Reshape back to 3D + helper.make_node( + "Reshape", + inputs=["logits_flat", "shape3d"], + outputs=["logits"], + ), + ] + + def _createDecoderConstants(self) -> list: + """ + Create constant tensors needed for decoder computation. + + :returns: list: List of constant tensor initializers. 
+ """ + constants = [] + + # Shape constants for reshaping operations + shape2d = numpy_helper.from_array( + np.array([-1, self.hidden_size], dtype=np.int64), + name="shape2d", + ) + + shape3d = numpy_helper.from_array( + np.array([0, -1, self.vocab_size], dtype=np.int64), + name="shape3d", + ) + + shapeBatch1 = numpy_helper.from_array( + np.array([-1, 1], dtype=np.int64), + name="shapeBatch1", + ) + + # Feature scaling factor + featureScale = numpy_helper.from_array( + np.array([[1.1]], dtype=np.float32), + name="featureScale", + ) + + constants.extend([shape2d, shape3d, shapeBatch1, featureScale]) + + return constants + + def _generateConfigFile(self, outputPath: Path) -> None: + """ + Generate the model configuration JSON file. + + :param outputPath (Path): Output path for the config.json file. + """ + config = self._getModelConfig() + + with open(outputPath, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2, ensure_ascii=False) + + def _getModelConfig(self) -> dict[str, Any]: + """ + Get the complete model configuration dictionary. + + :return: dict[str, Any]: Complete model configuration. 
+ """ + return { + "_name_or_path": "nlpconnect/vit-gpt2-image-captioning", + "architectures": ["VisionEncoderDecoderModel"], + "bos_token_id": 99, + "decoder": self._getDecoderConfig(), + "decoder_start_token_id": 99, + "encoder": self._getEncoderConfig(), + "eos_token_id": 99, + "is_encoder_decoder": True, + "model_type": "vision-encoder-decoder", + "pad_token_id": 99, + "tie_word_embeddings": False, + "transformers_version": "4.33.0.dev0", + } + + def _getDecoderConfig(self) -> dict[str, Any]: + """Get decoder-specific configuration.""" + return { + "_name_or_path": "", + "activation_function": "gelu_new", + "add_cross_attention": True, + "architectures": ["GPT2LMHeadModel"], + "attn_pdrop": 0.1, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": 99, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": 99, + "diversity_penalty": 0.0, + "do_sample": False, + "early_stopping": False, + "embd_pdrop": 0.1, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 99, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "initializer_range": 0.02, + "is_decoder": True, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_epsilon": 1e-05, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "gpt2", + "n_ctx": 1024, + "n_embd": 768, + "n_head": 12, + "n_inner": None, + "n_layer": 12, + "n_positions": 1024, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": 99, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "remove_invalid_values": False, + "reorder_and_upcast_attn": False, + "repetition_penalty": 1.0, + "resid_pdrop": 0.1, + "return_dict": 
True, + "return_dict_in_generate": False, + "scale_attn_by_inverse_layer_idx": False, + "scale_attn_weights": True, + "sep_token_id": None, + "summary_activation": None, + "summary_first_dropout": 0.1, + "summary_proj_to_labels": True, + "summary_type": "cls_index", + "summary_use_proj": True, + "suppress_tokens": None, + "task_specific_params": { + "text-generation": { + "do_sample": True, + "max_length": 50, + }, + }, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + "use_cache": True, + "vocab_size": self.vocab_size, + } + + def _getEncoderConfig(self) -> dict[str, Any]: + """Get encoder-specific configuration.""" + return { + "_name_or_path": "", + "add_cross_attention": False, + "architectures": ["ViTModel"], + "attention_probs_dropout_prob": 0.0, + "bad_words_ids": None, + "begin_suppress_tokens": None, + "bos_token_id": None, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": None, + "decoder_start_token_id": None, + "diversity_penalty": 0.0, + "do_sample": False, + "early_stopping": False, + "encoder_no_repeat_ngram_size": 0, + "encoder_stride": 16, + "eos_token_id": None, + "exponential_decay_length_penalty": None, + "finetuning_task": None, + "forced_bos_token_id": None, + "forced_eos_token_id": None, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.0, + "hidden_size": 768, + "id2label": {"0": "LABEL_0", "1": "LABEL_1"}, + "image_size": self.image_size, + "initializer_range": 0.02, + "intermediate_size": 3072, + "is_decoder": False, + "is_encoder_decoder": False, + "label2id": {"LABEL_0": 0, "LABEL_1": 1}, + "layer_norm_eps": 1e-12, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "vit", + "no_repeat_ngram_size": 0, + "num_attention_heads": 12, + "num_beam_groups": 1, + "num_beams": 1, + 
"num_channels": self.num_channels, + "num_hidden_layers": 12, + "num_return_sequences": 1, + "output_attentions": False, + "output_hidden_states": False, + "output_scores": False, + "pad_token_id": None, + "patch_size": self.patch_size, + "prefix": None, + "problem_type": None, + "pruned_heads": {}, + "qkv_bias": True, + "remove_invalid_values": False, + "repetition_penalty": 1.0, + "return_dict": True, + "return_dict_in_generate": False, + "sep_token_id": None, + "suppress_tokens": None, + "task_specific_params": None, + "temperature": 1.0, + "tf_legacy_loss": False, + "tie_encoder_decoder": False, + "tie_word_embeddings": True, + "tokenizer_class": None, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": None, + "torchscript": False, + "typical_p": 1.0, + "use_bfloat16": False, + } + + def _generateVocabFile(self, outputPath: Path) -> None: + """ + Generate the vocabulary JSON file. + + :param outputPath: Output path for the vocab.json file. + """ + vocab = self._getVocabulary() + + with open(outputPath, "w", encoding="utf-8") as f: + json.dump(vocab, f, indent=2, ensure_ascii=False) + + def _getVocabulary(self) -> dict[str, int]: + """ + Get the vocabulary mapping dictionary. + + :returns: dict[str, int]: Token to ID mapping. 
+ """ + return { + "<|endoftext|>": 50256, + "<|pad|>": 50257, + "a": 0, + "an": 1, + "the": 2, + "free": 3, + "or": 4, + "but": 5, + "in": 6, + "on": 7, + "at": 8, + "to": 9, + "and": 10, + "of": 11, + "with": 12, + "by": 13, + "man": 14, + "for": 15, + "desk": 16, + "people": 17, + "visual": 18, + "children": 19, + "software": 20, + "girl": 21, + "dog": 22, + "desktop": 23, + "car": 24, + "truck": 25, + "bus": 26, + "bike": 27, + "non-visual": 28, + "NVDA": 29, + "plane": 30, + "boat": 31, + "house": 32, + "access": 33, + "flower": 35, + "microsoft": 36, + "sky": 37, + "cloud": 38, + "sun": 39, + "moon": 40, + "water": 41, + "river": 42, + "ocean": 43, + "red": 44, + "blue": 45, + "reader": 46, + "yellow": 47, + "black": 48, + "white": 49, + "brown": 50, + "orange": 51, + "purple": 52, + "pink": 53, + "!": 54, + "small": 55, + "tall": 56, + "short": 57, + "old": 58, + "young": 59, + "beautiful": 61, + "ugly": 62, + "good": 63, + "bad": 64, + "sitting": 65, + "standing": 66, + "walking": 67, + "running": 68, + "screen": 69, + "drinking": 70, + "playing": 71, + "working": 72, + "is": 73, + "open": 74, + "was": 75, + "were": 76, + "has": 77, + "Best": 78, + "helping": 79, + "will": 80, + "would": 81, + "could": 82, + "should": 83, + "very": 84, + "quite": 85, + "really": 86, + "too": 87, + "also": 88, + "source": 89, + "only": 90, + "even": 91, + "still": 92, + "already": 93, + "windows": 96, + } diff --git a/tests/system/nvdaSettingsFiles/standard-doLoadMockModel.ini b/tests/system/nvdaSettingsFiles/standard-doLoadMockModel.ini new file mode 100644 index 00000000000..eff6d77689c --- /dev/null +++ b/tests/system/nvdaSettingsFiles/standard-doLoadMockModel.ini @@ -0,0 +1,20 @@ +schemaVersion = 2 +[general] + language = en + showWelcomeDialogAtStartup = False +[update] + askedAllowUsageStats = True + autoCheck = False + startupNotification = False + allowUsageStats = False +[speech] + synth = speechSpySynthDriver + unicodeNormalization = DISABLED +[development] + 
enableScratchpadDir = True +[virtualBuffers] + autoSayAllOnPageLoad = False + passThroughAudioIndication = False +[automatedImageDescriptions] + enable = True + defaultModel = mock/vit-gpt2-image-captioning diff --git a/tests/system/robot/automatedImageDescriptions.py b/tests/system/robot/automatedImageDescriptions.py new file mode 100644 index 00000000000..bfbd6b31a57 --- /dev/null +++ b/tests/system/robot/automatedImageDescriptions.py @@ -0,0 +1,43 @@ +# A part of NonVisual Desktop Access (NVDA) +# Copyright (C) 2025 NV Access Limited, Tianze +# This file may be used under the terms of the GNU General Public License, version 2 or later. +# For more details see: https://www.gnu.org/licenses/gpl-2.0.html + +"""Logic for automatedImageDescriptions tests.""" + +import os +import pathlib + +from ChromeLib import ChromeLib as _ChromeLib +from SystemTestSpy import ( + _getLib, +) +import NvdaLib as _nvdaLib + +_chrome: _ChromeLib = _getLib("ChromeLib") + + +def NVDA_Caption(): + spy = _nvdaLib.getSpyLib() + iconPath = os.path.join( + _nvdaLib._locations.repoRoot, + "source", + "images", + "nvda.ico", + ) + url = pathlib.Path(iconPath).as_uri() + + _chrome.prepareChrome( + f""" +