diff --git a/docs/guides/grpo-audio-visual.md b/docs/guides/grpo-audio-visual.md new file mode 100644 index 0000000000..9cf09198ff --- /dev/null +++ b/docs/guides/grpo-audio-visual.md @@ -0,0 +1,86 @@ +# Audio-Visual GRPO with Qwen2.5-Omni-7B + +This guide explains how to use NeMo RL to train [Qwen2.5-Omni-7B](https://huggingface.co/Qwen/Qwen2.5-Omni-7B) with GRPO on the [PhilipC/IntentTrain](https://huggingface.co/datasets/PhilipC/IntentTrain) audio-visual intent-recognition dataset and evaluate on [Daily-Omni](https://huggingface.co/datasets/liarliar/Daily-Omni), following the dataset structure used in [HumanOmniV2](https://arxiv.org/abs/2506.21277). + +Each training sample feeds the Qwen2.5-Omni processor both the video stream (8 frames) and the audio track decoded from the same file at 16 kHz mono. Audio and video flow as two **independent multimodal items** per prompt: the dataset emits `{type: video}` + `{type: audio}` content items, the Qwen2.5-Omni chat template renders both `<|VIDEO|>` and `<|AUDIO|>` placeholders, and vLLM rollouts populate `multi_modal_data["video"]` and `multi_modal_data["audio"]` from the same sample. + +## 1. Train the Model + +Run GRPO training with the provided config: + +``` +uv run examples/run_vlm_grpo.py --config examples/configs/intent_grpo_7B_megatron.yaml +``` + +Config: `examples/configs/intent_grpo_7B_megatron.yaml` + +Key hyperparameters: + +| Parameter | Value | +| --- | --- | +| Model | Qwen2.5-Omni-7B | +| Train dataset | PhilipC/IntentTrain (problem_type = "multiple choice") | +| Validation dataset | PhilipC/IntentBench (problem_type = "multiple choice") | +| Modalities per prompt | video (8 frames, `<\|VIDEO\|>` placeholder) + audio (16 kHz mono, `<\|AUDIO\|>` placeholder) — independent multimodal items, no `use_audio_in_video` alignment | +| GPUs | 8 x 1 node, Megatron backend, `tensor_model_parallel_size=2` (data parallel = 4) | +| Learning rate | 1e-6 | +| KL penalty | 0.01 | +| Generations per prompt | 8 | +| Prompts per step | 32 | +| Train global / micro batch | 32 / 1 | +| Max steps | 1000 | +| Save period | 20 | +| Reward | format (0.2) + exact_alnum (0.8) | + +The dataset class downloads `PhilipC/IntentTrain` and `PhilipC/IntentBench` via `huggingface_hub.snapshot_download` and extracts each `videos.zip` once into the corresponding HuggingFace cache directory. Re-instantiating the dataset on a machine that already has the archives extracted is a no-op. + +Only `problem_type == "multiple choice"` samples are used. The allow-list is configurable through `data.train.allowed_problem_types` and `data.validation.allowed_problem_types` if you want to extend scope (for example, to `emer_ov_mc`); doing so requires picking an answer-correctness reward that handles those answer formats. + +### 7B training notes + +- **8 video frames** keep the prompt around ~4.5k tokens (8×360 video + ~1.5k audio + text), under `max_total_sequence_length=8192`, and roughly halve the training-forward activation memory versus 16 frames. Do **not** switch to fps-based sampling — at fps=2 the clips expand to ~43k video tokens, blow past the token budget, and `vlm_hf_data_processor` then empties the multimodal items and sets `loss_multiplier=0`. +- **`activation_checkpointing: true` + `gpu_memory_utilization: 0.4`** keep the Megatron forward inside the memory vLLM leaves resident after sleep mode. If `tensor_model_parallel_size=2` OOMs, fall back to `tensor_model_parallel_size=4` (proven to run at 8 frames). +- If `loss_multiplier` is logged at 0 for many samples, the multimodal prompt is exceeding `max_total_sequence_length`; bump it until validation samples consistently produce non-zero loss. +- Set `HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1` once `Qwen/Qwen2.5-Omni-7B`, `PhilipC/IntentTrain`, and `PhilipC/IntentBench` are pre-fetched, so Megatron's tokenizer worker doesn't hit the network. + +## 2. Convert Checkpoint (Megatron to HF) + +Checkpoints are saved under `results/intent_grpo_7B_megatron` (`checkpointing.checkpoint_dir`), one every `save_period=20` steps. Convert a checkpoint from Megatron to Hugging Face format before evaluating: + +``` +uv run --extra mcore python examples/converters/convert_megatron_to_hf.py \ + --config results/intent_grpo_7B_megatron/step_43/config.yaml \ + --megatron-ckpt-path results/intent_grpo_7B_megatron/step_43/policy/weights/iter_0000000 \ + --hf-ckpt-path results/intent_grpo_7B_megatron/step_43/hf --no-strict +``` + +Replace the step number with the checkpoint you want to evaluate. `--no-strict` is expected here: only the Qwen2.5-Omni *thinker* is trained, so the talker tensors are reported as "not written". The `--extra mcore` flag is required for the Megatron converter. + +## 3. Evaluate + +In-training validation uses IntentBench as the validation set, so `val_period`, `val_batch_size`, and `max_val_samples` from the config drive evaluation cadence. + +For a standalone benchmark, decode the converted HF checkpoint on [Daily-Omni](https://huggingface.co/datasets/liarliar/Daily-Omni) (1197 audio-visual multiple-choice questions) with `examples/run_eval.py`: + +``` +uv run examples/run_eval.py --config examples/configs/evals/daily_omni.yaml \ + generation.model_name=results/intent_grpo_7B_megatron/step_43/hf +``` + +The eval config (`examples/configs/evals/daily_omni.yaml`) feeds audio + video (32 frames — eval has no training-forward memory pressure, so it samples more densely than training), uses the same think+answer prompt as training, and scores with `exact_alnum` (case-insensitive exact match on the `` content). + +## 4. Results + +Daily-Omni accuracy (1197 questions, greedy decoding) for the base Qwen2.5-Omni-7B versus the GRPO-trained checkpoint: + +| Question type | Base | After GRPO | +| --- | --- | --- | +| **Overall** | **0.498** | **0.590** | +| AV Event Alignment | 0.353 | 0.450 | +| Comparative | 0.618 | 0.725 | +| Context understanding | 0.446 | 0.534 | +| Event Sequence | 0.395 | 0.490 | +| Inference | 0.714 | 0.760 | +| Reasoning | 0.651 | 0.766 | + +GRPO lifts overall Daily-Omni accuracy by ~9 points, with gains across every question category. The largest relative gains are on the reasoning-style questions. diff --git a/docs/index.md b/docs/index.md index dc7a586928..1bb776b86c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -121,6 +121,13 @@ Configure offline and online Eagle3 draft-model workflows to accelerate rollout Train Qwen2.5-Omni-3B with GRPO on AVQA and evaluate on MMAU, following the R1-AQA approach. ::: +:::{grid-item-card} {octicon}`device-camera-video` Audio+Video Intent GRPO +:link: guides/grpo-audio-visual +:link-type: doc + +Train Qwen2.5-Omni-7B with GRPO on PhilipC/IntentTrain (audio-visual intent recognition) and evaluate on Daily-Omni, following HumanOmniV2's joint audio+video setup. +::: + :::{grid-item-card} {octicon}`plus-circle` Adding New Models :link: adding-new-models :link-type: doc @@ -259,6 +266,7 @@ guides/ppo.md guides/grpo-deepscaler.md guides/grpo-sliding-puzzle.md guides/grpo-audio.md +guides/grpo-audio-visual.md guides/rm.md guides/environments.md guides/eval.md diff --git a/examples/configs/evals/daily_omni.yaml b/examples/configs/evals/daily_omni.yaml new file mode 100644 index 0000000000..53d465a5a1 --- /dev/null +++ b/examples/configs/evals/daily_omni.yaml @@ -0,0 +1,82 @@ +eval: + metric: "pass@k" + num_tests_per_prompt: 1 + seed: 42 + k_value: 1 + save_path: results/daily_omni_decode.json + +generation: + backend: "vllm" + max_new_tokens: 2048 + temperature: 0.0 + top_p: 1.0 + top_k: -1 + num_prompts_per_step: -1 + model_name: "Qwen/Qwen2.5-Omni-3B" + stop_token_ids: null + stop_strings: null + vllm_cfg: + async_engine: false + precision: "bfloat16" + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + # 0.9 -> 0.5: with 32 video frames + audio, the Qwen2.5-Omni vision/audio + # encoder forward needs a large chunk of *transient activation* memory that + # lives outside vLLM's KV-cache budget. At 0.9 the KV cache claims almost + # all VRAM (56+ GiB) and the first multimodal forward OOM-crashes the vLLM + # workers (hard EOF, no graceful torch OOM). 0.5 leaves ample headroom; KV + # cache is still ~1M tokens, far more than eval needs. + gpu_memory_utilization: 0.5 + # Bumped from 16000 to fit 32 video frames + the 16 kHz audio track + # without truncating the multimodal prompt (truncation silently masks + # samples out and collapses their reward to 0). + max_model_len: 32000 + enforce_eager: False + skip_tokenizer_init: False + limit_mm_per_prompt: + video: 1 + audio: 1 + vllm_kwargs: + # Disable mm processor cache to avoid vLLM cache eviction during eval. + mm_processor_cache_gb: 0 + # Cap concurrent sequences so the Qwen2.5-Omni vision/audio encoder only + # processes a few clips per step. With audio + 32 video frames, vLLM + # otherwise batches ~66 clips into one encoder forward and OOM-crashes the + # workers (kv_cache_usage was ~2% at crash -> it is encoder *activation* + # memory, not KV cache). 8 keeps the encoder batch small; eval throughput + # is not a concern. + max_num_seqs: 8 + colocated: + enabled: true + resources: + gpus_per_node: null + num_nodes: null + +tokenizer: + name: ${generation.model_name} + chat_template: "default" + chat_template_kwargs: null + video: + # 16 -> 32 frames: 60s clips at 16 frames is ~1 frame / 3.75s, too sparse + # for fine-grained temporal (Event Sequence) questions. + num_frames: 32 + +data: + max_input_seq_length: ${generation.vllm_cfg.max_model_len} + prompt_file: examples/prompts/daily_omni.txt + system_prompt_file: null + dataset_name: "daily-omni" + split: "train" + env_name: vlm + +env: + vlm: + num_workers: 8 + reward_functions: + - name: exact_alnum + weight: 1.0 + +cluster: + gpus_per_node: 1 + num_nodes: 1 diff --git a/examples/configs/intent_grpo_7B_megatron.yaml b/examples/configs/intent_grpo_7B_megatron.yaml new file mode 100644 index 0000000000..3b5a7b8789 --- /dev/null +++ b/examples/configs/intent_grpo_7B_megatron.yaml @@ -0,0 +1,168 @@ +# Intent (audio+video) GRPO 7B Megatron configuration. +# +# Trains Qwen/Qwen2.5-Omni-7B with GRPO on PhilipC/IntentTrain (intent +# recognition over short MER24 / social_iq video clips with audio) and runs +# in-training validation on PhilipC/IntentBench. +# * Audio and video reach the model as two independent multimodal items +# per prompt: the dataset emits {type: video} + {type: audio}, the chat +# template renders <|VIDEO|> and <|AUDIO|> placeholders, and vLLM +# rollouts pass them as multi_modal_data["video"] / multi_modal_data["audio"]. +# use_audio_in_video=True / mm_processor_kwargs are NOT used because the +# installed transformers + vLLM Qwen2.5-Omni stack rejected that path. +# * Only problem_type == "multiple choice" samples are used; rewards reuse +# the audio recipe's format + exact_alnum. +# +# 7B requires more aggressive sharding than 3B to fit on 80 GB H100s alongside +# vLLM rollout memory: +# * tensor_model_parallel_size: 2 -> model state sharded across 2 ranks, +# data parallel size = gpus_per_node / TP = 4 with 8 GPUs. +# * per-forward batch must be exactly 1 sample/rank (train_micro_batch_size=1, +# logprob_batch_size=1), else the Qwen2.5-Omni get_rope_index path crashes +# with "IndexError: index 1 is out of bounds for dimension 0 with size 1". +# * num_frames 8 (vs the 3B recipe's 16) to roughly halve the prompt length +# and the training-forward activation memory. +# * activation_checkpointing on, vllm gpu_memory_utilization 0.4 to leave +# headroom for the Megatron forward. +# +# Inherits directly from grpo_math_1B_megatron.yaml (the same base the 3B +# recipe uses) and overrides intent-specific + 7B-specific settings. +defaults: "grpo_math_1B_megatron.yaml" + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 8 + max_num_steps: 1000 + val_at_start: false + max_val_samples: 256 + val_batch_size: 32 + +checkpointing: + enabled: true + checkpoint_dir: results/intent_grpo_7B_megatron + keep_top_k: 10 + # save_period 20: a 1-epoch (~85-step) 7B run is slow (~6 min/step) and + # previously hit the Slurm time limit at ~step 30 with checkpoints/ still + # EMPTY. 20 lands a checkpoint at steps 20/40/60/80. checkpoint_must_save_by + # additionally forces a save once 3h45m of wall-clock have elapsed so + # progress survives the job time limit (format DD:HH:MM:SS). + save_period: 20 + checkpoint_must_save_by: "00:03:45:00" + +policy: + model_name: Qwen/Qwen2.5-Omni-7B + # PER-FORWARD batch must be exactly 1 sample/rank, else the Qwen2.5-Omni + # get_rope_index path crashes with "IndexError: index 1 is out of bounds for + # dimension 0 with size 1" (input_ids batch > attention_mask batch). That is + # controlled by train_micro_batch_size=1 (train forward) and + # logprob_batch_size=1 (log-prob forward). train_global_batch_size=32 only + # sets gradient accumulation and must stay divisible by micro x DP + # (32 % (1 x DP=4) == 0). + train_global_batch_size: 32 + train_micro_batch_size: 1 + generation_batch_size: 32 + logprob_batch_size: 1 + # Audio + video produces materially more tokens than the audio-only recipe; + # this budget keeps loss_multiplier > 0 with headroom. The video frame count + # (tokenizer.video.num_frames) is the dominant lever on prompt length -- do + # not raise it (or switch to fps) without raising this too. + max_total_sequence_length: 8192 + + tokenizer: + video: + # 7B: 8 frames (vs the 3B recipe's 16) to roughly halve the prompt length + # (~7.3k -> ~4.5k tokens: 8x360 video + ~1.5k audio + text) and thus the + # training-forward activation memory. NOTE: stopgap -- the proper fix + # (matching HumanOmniV2, which only trains the LLM) is to FREEZE the + # vision/audio encoders, which needs a code hook (no YAML knob exists). + # DO NOT switch to fps-based sampling: fps=2 expands the clips to ~43k + # video tokens, blows past max_total_sequence_length / vLLM max_model_len, + # and vlm_hf_data_processor then empties the multimodal items + # (loss_multiplier=0). fps and num_frames are mutually exclusive. + num_frames: 8 + + sequence_packing: + enabled: false + + generation: + max_new_tokens: 1024 + vllm_cfg: + # Audio/multimodal models require tokenizer to be initialized before generation + skip_tokenizer_init: False + # 7B model state crowds the GPU; lower vLLM cache budget so Megatron has + # room for activations during the training-time forward pass. + gpu_memory_utilization: 0.4 + limit_mm_per_prompt: + video: 1 + audio: 1 + vllm_kwargs: + # Disable mm processor cache to avoid vLLM cache eviction assertion error during validation. + mm_processor_cache_gb: 0 + + megatron_cfg: + converter_type: Qwen2_5OmniForConditionalGeneration + apply_rope_fusion: false + activation_checkpointing: true + # TP=2 (DP=4 on 8 GPUs) -- 2x the data-parallel throughput of TP=4. Valid + # TP values are 1/2/4 (num_attention_heads=28 must be divisible by TP; TP=8 + # fails). At num_frames=8 (~4.5k-token sequence) the logits/activation + # memory is ~40% smaller than at 16 frames, so TP=2 fits. If it OOMs, fall + # back to tensor_model_parallel_size=4 (proven to run at 8 frames). + tensor_model_parallel_size: 2 + optimizer: + lr: 1.0e-6 + min_lr: 1.0e-7 + scheduler: + lr_warmup_iters: 10 + lr_warmup_init: 1.0e-7 + distributed_data_parallel_config: + overlap_grad_reduce: false + +data: + num_workers: 0 + train: + dataset_name: intent-train + split: train + allowed_problem_types: + - "multiple choice" + validation: + dataset_name: intent-bench + split: validation + allowed_problem_types: + - "multiple choice" + default: + prompt_file: null + system_prompt_file: null + processor: "vlm_hf_data_processor" + env_name: "vlm" + +env: + vlm: + num_workers: 8 + # Strict two-signal reward (format + accuracy), same structure as the + # HumanOmniV2 reference. The IntentDataset prompt instructs the model to + # reason between and commit the answer between + # tags: + # * format -- rewards the ...... + # structure (does not gate correctness). + # * exact_alnum -- case-insensitive exact match on the content; + # returns 0 when the tag is missing, so the model + # must emit the wrapped form to earn the accuracy signal. + reward_functions: + - name: format + weight: 0.2 + - name: exact_alnum + weight: 0.8 + +logger: + wandb_enabled: true + tensorboard_enabled: true + monitor_gpus: false + wandb: + project: grpo-dev + name: intent-grpo-7b-megatron + swanlab: + project: grpo-dev + name: intent-grpo-7b-megatron + +cluster: + gpus_per_node: 8 diff --git a/examples/prompts/daily_omni.txt b/examples/prompts/daily_omni.txt new file mode 100644 index 0000000000..e5d1469e1f --- /dev/null +++ b/examples/prompts/daily_omni.txt @@ -0,0 +1 @@ +{} First reason briefly between tags, then output only the single option letter (e.g., A, B, C, D, ...) between tags. Format example: your reasoningA diff --git a/nemo_rl/data/__init__.py b/nemo_rl/data/__init__.py index 04a7e73ae4..c71cc328a0 100644 --- a/nemo_rl/data/__init__.py +++ b/nemo_rl/data/__init__.py @@ -177,6 +177,32 @@ class MMAUEvalDataConfig(TypedDict): env_name: NotRequired[str] +class DailyOmniEvalDataConfig(TypedDict): + """Config for the Daily-Omni audio-visual eval dataset. + + Mirrors the MMAU multimodal schema but with its own ``dataset_name`` literal + so the eval-config union resolves daily-omni unambiguously. Kept as a + ``TypedDict`` for consistency with the other (still v1) eval-data configs in + this union, whose consumers access the resolved config by key + (``config.data["dataset_name"]``). + + Fields: + max_input_seq_length: Max prompt length passed to the generation backend. + dataset_name: Must be ``"daily-omni"``. + split: HuggingFace split to load. + prompt_file: Optional prompt template path. + system_prompt_file: Optional system prompt path. + env_name: Reward/eval environment name (e.g. ``"vlm"``). + """ + + max_input_seq_length: int + dataset_name: Literal["daily-omni"] + split: NotRequired[str | None] + prompt_file: NotRequired[str | None] + system_prompt_file: NotRequired[str | None] + env_name: NotRequired[str] + + # Union type for all eval dataset configs EvalDataConfigType = Union[ MMLUEvalDataConfig, @@ -185,5 +211,6 @@ class MMAUEvalDataConfig(TypedDict): GPQAEvalDataConfig, MathEvalDataConfig, MMAUEvalDataConfig, + DailyOmniEvalDataConfig, LocalMathEvalDataConfig, ] diff --git a/nemo_rl/data/collate_fn.py b/nemo_rl/data/collate_fn.py index 6f4291aa43..86f91b247e 100644 --- a/nemo_rl/data/collate_fn.py +++ b/nemo_rl/data/collate_fn.py @@ -117,6 +117,7 @@ def eval_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]: message_log = [datum_spec["message_log"] for datum_spec in data_batch] extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch] idx = [datum_spec["idx"] for datum_spec in data_batch] + task_names = [datum_spec.get("task_name", None) for datum_spec in data_batch] # Check if any of the data batch has vllm content (multimodal data) extra_args = {} @@ -132,11 +133,15 @@ def eval_collate_fn(data_batch: list[DatumSpec]) -> BatchedDataDict[Any]: extra_args["vllm_audios"] = [ datum_spec.get("vllm_audios", []) for datum_spec in data_batch ] + extra_args["vllm_videos"] = [ + datum_spec.get("vllm_videos", []) for datum_spec in data_batch + ] output: BatchedDataDict[Any] = BatchedDataDict( message_log=message_log, extra_env_info=extra_env_info, idx=idx, + task_name=task_names, **extra_args, ) return output diff --git a/nemo_rl/data/datasets/eval_datasets/__init__.py b/nemo_rl/data/datasets/eval_datasets/__init__.py index 296323efda..2243b37234 100644 --- a/nemo_rl/data/datasets/eval_datasets/__init__.py +++ b/nemo_rl/data/datasets/eval_datasets/__init__.py @@ -15,6 +15,7 @@ from typing import cast from nemo_rl.data.datasets.eval_datasets.aime import AIMEDataset, AIMEVariant +from nemo_rl.data.datasets.eval_datasets.daily_omni import DailyOmniEvalDataset from nemo_rl.data.datasets.eval_datasets.gpqa import GPQADataset from nemo_rl.data.datasets.eval_datasets.local_math_dataset import LocalMathDataset from nemo_rl.data.datasets.eval_datasets.math import MathDataset @@ -23,7 +24,7 @@ from nemo_rl.data.datasets.eval_datasets.mmlu_pro import MMLUProDataset # Dataset names that require multimodal (VLM) processing -MULTIMODAL_DATASETS = {"mmau", "TwinkStart/MMAU"} +MULTIMODAL_DATASETS = {"mmau", "TwinkStart/MMAU", "daily-omni"} def _is_multimodal_dataset(dataset_name): @@ -94,6 +95,14 @@ def load_eval_dataset(data_config): dataset_name="TwinkStart/MMAU", split=split, ) + # daily-omni + elif dataset_name == "daily-omni": + split = data_config.get("split", "train") + base_dataset = DailyOmniEvalDataset( + split=split, + prompt_file=data_config.get("prompt_file"), + system_prompt_file=data_config.get("system_prompt_file"), + ) # fall back to local dataset else: print(f"Loading dataset from {dataset_name}...") @@ -112,6 +121,7 @@ def load_eval_dataset(data_config): __all__ = [ "AIMEDataset", + "DailyOmniEvalDataset", "GPQADataset", "LocalMathDataset", "MathDataset", diff --git a/nemo_rl/data/datasets/eval_datasets/daily_omni.py b/nemo_rl/data/datasets/eval_datasets/daily_omni.py new file mode 100644 index 0000000000..e37968392a --- /dev/null +++ b/nemo_rl/data/datasets/eval_datasets/daily_omni.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Daily-Omni evaluation dataset wrapper.""" + +import re +from typing import Any, Optional + +from nemo_rl.data.datasets.response_datasets.daily_omni import DailyOmniDataset +from nemo_rl.data.interfaces import TaskDataSpec +from nemo_rl.data.processors import vlm_hf_data_processor + +# The training-side ``DailyOmniDataset.get_prompt`` ends with a hard +# "must contain only a single letter" instruction that overrides any later +# ```` formatting request. Strip it for eval so the prompt_file template +# can dictate output formatting without conflict. +_SINGLE_LETTER_LINE = re.compile( + r"\n+Your replies must contain only a single letter[^\n]*" +) + + +class DailyOmniEvalDataset: + """Daily-Omni evaluation dataset. + + Reuses the response-side ``DailyOmniDataset`` (HF snapshot, tar extraction, + qa.json load) and exposes the attributes that ``run_eval.py`` needs: + ``rekeyed_ds``, ``task_spec``, ``processor``, and ``preprocessor``. + + ``prompt_file`` / ``system_prompt_file`` are optional templates with a single + ``{}`` placeholder for the question text — used by ``vlm_hf_data_processor`` + to wrap the user message (e.g. to require `` `` formatting). + """ + + def __init__( + self, + split: str = "train", + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + ): + self._base = DailyOmniDataset(split=split) + self.rekeyed_ds = self._base.dataset + self.task_spec = TaskDataSpec( + task_name=self._base.task_name, + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = vlm_hf_data_processor + self.preprocessor = self._format_for_eval + + def _format_for_eval(self, data: dict[str, Any]) -> dict[str, Any]: + out = self._base.format_data(data) + # Content order is [video, audio, text]; locate the text item by type + # rather than a fixed index so it stays correct as media items change. + text_item = next( + item for item in out["messages"][0]["content"] if item["type"] == "text" + ) + text_item["text"] = _SINGLE_LETTER_LINE.sub("", text_item["text"]) + return out diff --git a/nemo_rl/data/datasets/response_datasets/__init__.py b/nemo_rl/data/datasets/response_datasets/__init__.py index a7cd82e57f..810c40a05a 100644 --- a/nemo_rl/data/datasets/response_datasets/__init__.py +++ b/nemo_rl/data/datasets/response_datasets/__init__.py @@ -30,6 +30,10 @@ from nemo_rl.data.datasets.response_datasets.geometry3k import Geometry3KDataset from nemo_rl.data.datasets.response_datasets.gsm8k import GSM8KDataset from nemo_rl.data.datasets.response_datasets.helpsteer3 import HelpSteer3Dataset +from nemo_rl.data.datasets.response_datasets.intent import ( + IntentBenchDataset, + IntentTrainDataset, +) from nemo_rl.data.datasets.response_datasets.nemogym_dataset import NemoGymDataset from nemo_rl.data.datasets.response_datasets.nemotron_cascade2_sft import ( NemotronCascade2SFTMathDataset, @@ -62,6 +66,8 @@ "GSM8K": GSM8KDataset, "geometry3k": Geometry3KDataset, "HelpSteer3": HelpSteer3Dataset, + "intent-train": IntentTrainDataset, + "intent-bench": IntentBenchDataset, "open_assistant": OasstDataset, "OpenMathInstruct-2": OpenMathInstruct2Dataset, "refcoco": RefCOCODataset, @@ -131,6 +137,8 @@ def load_response_dataset(data_config: ResponseDatasetConfig): "DeepScalerDataset", "Geometry3KDataset", "HelpSteer3Dataset", + "IntentBenchDataset", + "IntentTrainDataset", "NemoGymDataset", "NemotronCascade2SFTMathDataset", "OasstDataset", diff --git a/nemo_rl/data/datasets/response_datasets/daily_omni.py b/nemo_rl/data/datasets/response_datasets/daily_omni.py index b2307e337f..d5bdde54c5 100644 --- a/nemo_rl/data/datasets/response_datasets/daily_omni.py +++ b/nemo_rl/data/datasets/response_datasets/daily_omni.py @@ -15,6 +15,7 @@ import os from typing import Any +import numpy as np from huggingface_hub import snapshot_download from nemo_rl.data.datasets.raw_dataset import RawDataset @@ -24,6 +25,28 @@ ) +def _load_audio_16k_mono(path: str) -> np.ndarray: + """Decode an audio file as a 1-D float32 array at 16 kHz mono. + + Daily-Omni ships each clip's audio track as a sibling ``*_audio.wav`` next + to ``*_video.mp4``. We feed it as an independent ``{type: audio}`` content + item (mirroring the IntentTrain training path) so the Qwen2.5-Omni chat + template renders an ``<|AUDIO|>`` placeholder and vLLM populates + ``multi_modal_data["audio"]``. The benchmark is audio-visual, so video + frames alone leave audio-dependent questions unanswerable. Uses decord + (already a project dependency for video decoding) for the same 16 kHz mono + pipeline the training path uses. + """ + import decord + + reader = decord.AudioReader(path, sample_rate=16000, mono=True) + # Shape: (channels, T). With mono=True channels=1; squeeze to (T,). + audio = reader[:].asnumpy() + if audio.ndim > 1: + audio = audio[0] + return audio.astype(np.float32) + + class DailyOmniDataset(RawDataset): """Simple wrapper around the Daily-Omni dataset. @@ -116,20 +139,16 @@ def get_prompt(cls, data: dict[str, Any]) -> str: return prompt def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + video_dir = os.path.join(self.hf_cache_dir, "Videos", data["video_id"]) + video_path = os.path.join(video_dir, data["video_id"] + "_video.mp4") + audio_path = os.path.join(video_dir, data["video_id"] + "_audio.wav") + # Audio + video flow as two independent content items so the + # Qwen2.5-Omni chat template renders both <|VIDEO|> and <|AUDIO|> + # placeholders (Daily-Omni is an audio-visual benchmark). user_content = [ - { - "type": "video", - "video": os.path.join( - self.hf_cache_dir, - "Videos", - data["video_id"], - data["video_id"] + "_video.mp4", - ), - }, - { - "type": "text", - "text": self.get_prompt(data), - }, + {"type": "video", "video": video_path}, + {"type": "audio", "audio": _load_audio_16k_mono(audio_path)}, + {"type": "text", "text": self.get_prompt(data)}, ] return { "messages": [ diff --git a/nemo_rl/data/datasets/response_datasets/intent.py b/nemo_rl/data/datasets/response_datasets/intent.py new file mode 100644 index 0000000000..a574d6dd8f --- /dev/null +++ b/nemo_rl/data/datasets/response_datasets/intent.py @@ -0,0 +1,381 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IntentDataset: HumanOmniV2 IntentTrain / IntentBench loader for GRPO. + +Loads the PhilipC/IntentTrain (training) or PhilipC/IntentBench (validation) +datasets that ship as a JSON manifest plus a ``videos.zip`` archive on +HuggingFace, filters samples to the configured ``problem_type`` allow-list, and +emits OpenAI-style messages whose user content carries both a video reference +and the audio track extracted from that same video. Audio and video flow as +two independent ``{type:audio}`` / ``{type:video}`` content items so the +Qwen2.5-Omni chat template renders both ``<|VIDEO|>`` and ``<|AUDIO|>`` +placeholders into the prompt -- vLLM's multimodal prompt replacement on the +rollout side requires those placeholders to exist before it accepts matching +``mm_items``. The ``use_audio_in_video=True`` time-alignment hint is NOT +threaded through here because the installed transformers + vLLM stack +rejected that path during Round 1 testing (see BitLesson +BL-20260428-omni-use-audio-in-video). +""" + +import ast +import json +import logging +import os +import zipfile +from typing import Any + +import numpy as np +from huggingface_hub import snapshot_download + +from nemo_rl.data.datasets.raw_dataset import RawDataset +from nemo_rl.data.datasets.utils import get_huggingface_cache_path + +logger = logging.getLogger(__name__) + +# Per-problem-type instruction appended to the question. The wording asks +# the model to first think between ... tags and then commit +# the final answer between ... tags so both NeMo-RL reward +# functions (format_reward checks for + ; exact_alnum +# extracts content from ) can score the response. Without the +# explicit "" instruction the base Qwen2.5-Omni-3B emits a bare +# letter (e.g. "B") and both rewards collapse to 0. +_TYPE_TEMPLATE = { + "multiple choice": ( + " First reason briefly between tags, then output " + "only the single option letter (e.g., A, B, C, D, ...) between " + " tags. Format example: " + "your reasoningA" + ), + "emer_ov_mc": ( + " First reason briefly between tags, then output " + "the single or multi-letter answer (e.g., A for single, A,E for " + "multiple) between tags. Format example: " + "your reasoningA,E" + ), + "numerical": ( + " First reason briefly between tags, then output " + "the numerical value (e.g., 42 or 3.14) between " + "tags. Format example: your reasoning42" + ), + "judge": ( + " First reason briefly between tags, then answer " + "Yes or No between tags. Format example: " + "your reasoningYes" + ), + "free-form": ( + " First reason briefly between tags, then provide " + "your final text answer between tags. Format " + "example: your reasoningyour answer" + ), +} + + +def _format_options(options: Any) -> str: + """Render a record's multiple-choice options into the prompt text. + + IntentTrain/IntentBench manifests store ``options`` as a list of strings + like ``["A.first choice", "B.second choice", ...]`` (occasionally as a + string repr of that list). These MUST be appended to the prompt: without + them the model only sees the question stem and has to emit a bare option + letter blind (capping accuracy near chance). Mirrors HumanOmniV2's prompt + construction. Returns an empty string when no options are present. + """ + if not options: + return "" + if isinstance(options, str): + try: + options = ast.literal_eval(options) + except (ValueError, SyntaxError): + return f" Options:\n{options}" + if isinstance(options, (list, tuple)): + return " Options:\n" + "\n".join(str(o) for o in options) + return f" Options:\n{options}" + + +# Per-split HF repo + manifest filenames for the HumanOmniV2 IntentTrain / +# IntentBench releases. Each split downloads a videos.zip and one or more JSON +# manifests; manifest entries point at relative paths inside the extracted +# archive. +_SPLIT_CONFIG = { + "train": { + "repo_id": "PhilipC/IntentTrain", + "manifests": ["emer_rewrite.json", "social_iq_v2_rewrite.json"], + "task_name": "intent-train", + }, + "validation": { + "repo_id": "PhilipC/IntentBench", + "manifests": ["qa.json"], + "task_name": "intent-bench", + }, +} + +_EXTRACTION_SENTINEL = ".intent_videos_extracted" + + +def _extract_videos_zip_once(snapshot_dir: str) -> str: + """Idempotently extract ``videos.zip`` inside ``snapshot_dir``. + + Returns the directory the archive was extracted into. A sentinel file is + written after a successful extraction so subsequent constructions skip + re-extraction. + """ + archive = os.path.join(snapshot_dir, "videos.zip") + if not os.path.isfile(archive): + raise FileNotFoundError( + f"videos.zip not found in HuggingFace snapshot at {snapshot_dir}. " + "Was the dataset downloaded correctly?" + ) + + sentinel = os.path.join(snapshot_dir, _EXTRACTION_SENTINEL) + if os.path.isfile(sentinel): + return snapshot_dir + + with zipfile.ZipFile(archive, "r") as zf: + zf.extractall(snapshot_dir) + + with open(sentinel, "w", encoding="utf-8") as f: + f.write("ok\n") + return snapshot_dir + + +def _resolve_video_path(snapshot_dir: str, relpath: str) -> str | None: + """Resolve a manifest's relative video path to an absolute file on disk. + + The IntentTrain/IntentBench archives extract their contents either directly + under the snapshot directory or under a ``videos/`` subdirectory. Try both + and return the first path that exists, or ``None`` if neither does. + """ + candidate = os.path.join(snapshot_dir, relpath) + if os.path.isfile(candidate): + return candidate + candidate = os.path.join(snapshot_dir, "videos", relpath) + if os.path.isfile(candidate): + return candidate + return None + + +def _load_audio_from_video(video_path: str, sampling_rate: int = 16000) -> np.ndarray: + """Decode the audio track of a video file as a 1-D float32 array. + + Uses decord's ``AudioReader`` because it's already a project dependency for + video decoding. Raises ``RuntimeError`` if the video has no decodable audio + track so callers can drop or skip the sample. + """ + import decord + + try: + reader = decord.AudioReader(video_path, sample_rate=sampling_rate, mono=True) + # Shape: (channels, T). With mono=True channels=1; squeeze to (T,). + audio = reader[:].asnumpy() + if audio.ndim > 1: + audio = audio[0] + return audio.astype(np.float32) + except Exception as e: # decord raises a variety of errors for missing audio + raise RuntimeError(f"Failed to decode audio from {video_path}: {e}") from e + + +def _read_manifest(snapshot_dir: str, manifest_filename: str) -> list[dict[str, Any]]: + manifest_path = os.path.join(snapshot_dir, manifest_filename) + if not os.path.isfile(manifest_path): + raise FileNotFoundError( + f"Manifest {manifest_filename} not found in HF snapshot at " + f"{snapshot_dir}. Available files: {sorted(os.listdir(snapshot_dir))}" + ) + with open(manifest_path, "r", encoding="utf-8") as f: + if manifest_filename.endswith(".jsonl"): + return [json.loads(line) for line in f if line.strip()] + return json.load(f) + + +class IntentDataset(RawDataset): + """HumanOmniV2 IntentTrain / IntentBench loader for VLM GRPO. + + Each sample emits both a video file path and a 16 kHz mono audio array + decoded from that same file as two independent content items + (``{type:video}`` and ``{type:audio}``) plus a text prompt. The + Qwen2.5-Omni processor and vLLM rollout both treat the two streams as + independent multimodal sources; the explicit time-alignment via + ``use_audio_in_video=True`` is intentionally not used in v1 because the + installed transformers + vLLM stack rejected that path. Samples whose + ``problem_type`` is not in ``allowed_problem_types`` are dropped before + iteration. + + Args: + split: ``"train"`` (PhilipC/IntentTrain) or ``"validation"`` + (PhilipC/IntentBench). + allowed_problem_types: List of ``problem_type`` values to retain. + Defaults to ``["multiple choice"]`` per DEC-2. + max_samples: Optional cap on the number of samples after filtering. + Useful for smoke runs. + """ + + def __init__( + self, + split: str = "train", + allowed_problem_types: list[str] | None = None, + max_samples: int | None = None, + **kwargs: Any, + ) -> None: + if split not in _SPLIT_CONFIG: + raise ValueError( + f"Invalid split: {split!r}. Supported: {sorted(_SPLIT_CONFIG.keys())}." + ) + self.split = split + self._cfg = _SPLIT_CONFIG[split] + self.task_name = self._cfg["task_name"] + self.allowed_problem_types = list( + allowed_problem_types + if allowed_problem_types is not None + else ["multiple choice"] + ) + + self.snapshot_dir = self._download_and_extract() + + records = self._load_records() + records = self._filter_records(records) + if max_samples is not None: + records = records[:max_samples] + if not records: + raise ValueError( + f"IntentDataset({split=}) yielded 0 samples after filtering by " + f"allowed_problem_types={self.allowed_problem_types}. " + "Check the manifest contents and filter list." + ) + + from datasets import Dataset + + self.dataset = Dataset.from_list(records) + self.dataset = self.dataset.add_column( + "task_name", [self.task_name] * len(self.dataset) + ) + self.preprocessor = self.format_data + self.val_dataset = None + + def _download_and_extract(self) -> str: + """Download the HF dataset snapshot and extract ``videos.zip`` once.""" + repo_id = self._cfg["repo_id"] + cache_dir = get_huggingface_cache_path(repo_id) + if not cache_dir: + cache_dir = snapshot_download(repo_id=repo_id, repo_type="dataset") + if not cache_dir: + raise ValueError(f"Cannot download {repo_id}.") + return _extract_videos_zip_once(cache_dir) + + def _load_records(self) -> list[dict[str, Any]]: + records: list[dict[str, Any]] = [] + for manifest in self._cfg["manifests"]: + try: + manifest_records = _read_manifest(self.snapshot_dir, manifest) + except FileNotFoundError: + if len(self._cfg["manifests"]) == 1: + raise + logger.warning( + "Manifest %s missing in snapshot %s; skipping", + manifest, + self.snapshot_dir, + ) + continue + records.extend(manifest_records) + if not records: + raise ValueError( + f"No manifest entries loaded for {self._cfg['repo_id']}. " + f"Expected one of: {self._cfg['manifests']}." + ) + return records + + def _filter_records(self, records: list[dict[str, Any]]) -> list[dict[str, Any]]: + allowed = set(self.allowed_problem_types) + filtered: list[dict[str, Any]] = [] + for record in records: + problem_type = record.get("problem_type") + if problem_type not in allowed: + continue + data_type = record.get("data_type", "video") + if data_type != "video": + # Mixed modalities (e.g. image-only entries from + # Video-R1_rewrite.json) are out of scope; the recipe is + # video-first per DEC-1 / DEC-2. + continue + relpath = record.get("video") or record.get("path") + if not isinstance(relpath, str): + continue + local_path = _resolve_video_path(self.snapshot_dir, relpath) + if local_path is None: + logger.warning( + "Skipping manifest entry: video not found for relpath=%s", + relpath, + ) + continue + filtered.append( + { + "problem": record.get("problem", ""), + "problem_type": problem_type, + "answer": record.get("answer", ""), + "options": record.get("options"), + "video_path": local_path, + } + ) + return filtered + + def format_data(self, data: dict[str, Any]) -> dict[str, Any]: + """Format a manifest record into NeMo-RL OpenAI-style messages. + + Each yielded sample carries the video file path AND the audio track + decoded from that same file at 16 kHz mono. Both arrive as + independent ``{type: video}`` / ``{type: audio}`` content items so + the Qwen2.5-Omni chat template renders both ``<|VIDEO|>`` and + ``<|AUDIO|>`` placeholders in the prompt; vLLM's multimodal prompt + replacement on the rollout side requires those placeholders to exist + in the prompt before it will accept matching ``mm_items``. + + We deliberately do NOT pass ``use_audio_in_video=True`` to the + processor in v1: that flag would entangle the audio and video + placeholder accounting in ways the current installed transformers + + vLLM stack does not handle (see Round 1 BitLesson). The model + still receives both modalities; the only thing missing is the + explicit time alignment hint. + """ + instruction = _TYPE_TEMPLATE.get(data["problem_type"], "") + options_text = _format_options(data.get("options")) + prompt_text = f"{data['problem']}{options_text}{instruction}" + audio_array = _load_audio_from_video(data["video_path"]) + user_content = [ + {"type": "video", "video": data["video_path"]}, + {"type": "audio", "audio": audio_array}, + {"type": "text", "text": prompt_text}, + ] + return { + "messages": [ + {"role": "user", "content": user_content}, + {"role": "assistant", "content": str(data["answer"])}, + ], + "task_name": self.task_name, + } + + +class IntentTrainDataset(IntentDataset): + """Convenience wrapper that pins ``split="train"`` for IntentTrain.""" + + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("split", "train") + super().__init__(**kwargs) + + +class IntentBenchDataset(IntentDataset): + """Convenience wrapper that pins ``split="validation"`` for IntentBench.""" + + def __init__(self, **kwargs: Any) -> None: + kwargs.setdefault("split", "validation") + super().__init__(**kwargs) diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py index 13e24c6add..82082962bb 100644 --- a/nemo_rl/data/processors.py +++ b/nemo_rl/data/processors.py @@ -461,6 +461,7 @@ def vlm_hf_data_processor( from nemo_rl.data.multimodal_utils import ( PackedTensor, get_dim_to_pack_along, + get_multimodal_default_settings_from_processor, get_multimodal_keys_from_processor, resolve_to_image, ) @@ -478,6 +479,10 @@ def vlm_hf_data_processor( pass # AudioMCQ data is already formatted by AudioMCQDataset.format_data elif datum_dict["task_name"] == "mmau": pass # MMAU data is already formatted by MMAUDataset.format_data + elif datum_dict["task_name"] == "daily-omni": + pass # Daily-Omni data is already formatted by DailyOmniDataset.format_data + elif datum_dict["task_name"] in ("intent-train", "intent-bench"): + pass # IntentDataset.format_data already produces the message structure else: raise ValueError(f"No data processor for task {datum_dict['task_name']}") @@ -493,6 +498,8 @@ def vlm_hf_data_processor( # images = [] audios = [] + videos = [] + load_video_kwargs: dict[str, Any] = {} if isinstance(problem, list): for content in problem: # for image, video, audio, just append it @@ -515,6 +522,21 @@ def vlm_hf_data_processor( audios.append( (content["audio"], processor.feature_extractor.sampling_rate) ) + elif content["type"] == "video": + from transformers.video_utils import load_video + + if not load_video_kwargs: + load_video_kwargs = get_multimodal_default_settings_from_processor( + processor + ).get("video", {}) + video_value = content["video"] + if isinstance(video_value, str): + video_value = load_video( + video_value, backend="decord", **load_video_kwargs + )[0] + # Replace path with loaded frames so apply_chat_template can consume it + user_message["content"].append({"type": "video", "video": video_value}) + videos.append(video_value) else: raise ValueError(f"Unsupported content type: {content['type']}") else: @@ -576,6 +598,7 @@ def vlm_hf_data_processor( "vllm_content": None, "vllm_images": [], "vllm_audios": [], + "vllm_videos": [], } # make smaller and mask out @@ -593,6 +616,7 @@ def vlm_hf_data_processor( "vllm_content": string_formatted_dialog, "vllm_images": images, "vllm_audios": audios, + "vllm_videos": videos, } output: DatumSpec = { diff --git a/nemo_rl/evals/eval.py b/nemo_rl/evals/eval.py index 28c394ef25..8a5d7ea097 100644 --- a/nemo_rl/evals/eval.py +++ b/nemo_rl/evals/eval.py @@ -343,6 +343,11 @@ async def _run_env_eval_impl( multi_modal_data["image"] = ( images[i][0] if len(images[i]) == 1 else images[i] ) + videos = batch.get("vllm_videos", None) + if videos is not None and len(videos[i]) > 0: + multi_modal_data["video"] = ( + videos[i][0] if len(videos[i]) == 1 else videos[i] + ) if multi_modal_data: prompt_dict["multi_modal_data"] = multi_modal_data prompts.append(prompt_dict) diff --git a/nemo_rl/models/generation/vllm/utils.py b/nemo_rl/models/generation/vllm/utils.py index 349d36fabf..fe49ba3b75 100644 --- a/nemo_rl/models/generation/vllm/utils.py +++ b/nemo_rl/models/generation/vllm/utils.py @@ -70,7 +70,7 @@ def _get_regular_prompt(index: int): continue # init prompt dict prompt_dict = {"prompt": msg} - # collect multi_modal_data from images and audios + # collect multi_modal_data from images, audios, and videos multi_modal_data = {} images = data.get("vllm_images", None) if images is not None and len(images[i]) > 0: @@ -82,6 +82,11 @@ def _get_regular_prompt(index: int): multi_modal_data["audio"] = ( audios[i][0] if len(audios[i]) == 1 else audios[i] ) + videos = data.get("vllm_videos", None) + if videos is not None and len(videos[i]) > 0: + multi_modal_data["video"] = ( + videos[i][0] if len(videos[i]) == 1 else videos[i] + ) if not multi_modal_data: prompts.append(_get_regular_prompt(i)) continue diff --git a/pyrefly.toml b/pyrefly.toml index a1486f0a64..4d6e609dc6 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -59,6 +59,7 @@ project-includes = [ "nemo_rl/data/datasets/__init__.py", "nemo_rl/data/datasets/eval_datasets/__init__.py", "nemo_rl/data/datasets/eval_datasets/aime.py", + "nemo_rl/data/datasets/eval_datasets/daily_omni.py", "nemo_rl/data/datasets/eval_datasets/gpqa.py", "nemo_rl/data/datasets/eval_datasets/local_math_dataset.py", "nemo_rl/data/datasets/eval_datasets/math.py", diff --git a/tests/functional/L1_Functional_Tests_Megatron_1.sh b/tests/functional/L1_Functional_Tests_Megatron_1.sh index c7c6571aa3..e26f3e832f 100644 --- a/tests/functional/L1_Functional_Tests_Megatron_1.sh +++ b/tests/functional/L1_Functional_Tests_Megatron_1.sh @@ -35,6 +35,7 @@ run_test() { } run_test fast uv run --no-sync bash ./tests/functional/audio_grpo_megatron.sh +run_test uv run --no-sync bash ./tests/functional/audio_visual_grpo_megatron.sh run_test uv run --no-sync bash ./tests/functional/grpo_megatron.sh run_test uv run --no-sync bash ./tests/functional/grpo_megatron_mbridge_restore.sh run_test fast uv run --no-sync bash ./tests/functional/grpo_megatron_eagle3_online.sh diff --git a/tests/functional/audio_visual_grpo_megatron.sh b/tests/functional/audio_visual_grpo_megatron.sh new file mode 100644 index 0000000000..42f68283f7 --- /dev/null +++ b/tests/functional/audio_visual_grpo_megatron.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR + +cd $PROJECT_ROOT +# Audio+video intent recipe (IntentTrain): both <|VIDEO|> and <|AUDIO|> reach +# the model as independent multimodal items. Uses the 7B recipe config but +# pins the lighter Qwen2.5-Omni-3B so the functional test stays fast. +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/run_vlm_grpo.py \ + --config $PROJECT_ROOT/examples/configs/intent_grpo_7B_megatron.yaml \ + policy.model_name=Qwen/Qwen2.5-Omni-3B \ + grpo.num_prompts_per_step=2 \ + grpo.num_generations_per_prompt=4 \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + grpo.max_num_steps=2 \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=false \ + checkpointing.enabled=false \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +uv run tests/check_metrics.py $JSON_METRICS \ + 'max(data["train/reward"]) > 0.6' \ + 'mean(data["train/token_mult_prob_error"]) < 1.05' diff --git a/tests/unit/data/datasets/test_intent_dataset.py b/tests/unit/data/datasets/test_intent_dataset.py new file mode 100644 index 0000000000..29148c0409 --- /dev/null +++ b/tests/unit/data/datasets/test_intent_dataset.py @@ -0,0 +1,34 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the IntentTrain / IntentBench dataset loader. + +The audio+video sample-shape contract (every prompt carries one +``{type:video}`` + one ``{type:audio}`` + a text item, so the chat template +emits both ``<|VIDEO|>`` and ``<|AUDIO|>`` placeholders) is exercised end to +end by the functional test ``tests/functional/audio_visual_grpo_megatron.sh`` +and by the vLLM-utils unit tests. The dedicated unit check for it required +``ffmpeg`` to fabricate an mp4 with an audio track, so it is intentionally not +included here — the unit suite stays ffmpeg-free. +""" + +import pytest + + +class TestIntentDataset: + def test_intent_invalid_split_raises(self): + from nemo_rl.data.datasets.response_datasets.intent import IntentDataset + + with pytest.raises(ValueError, match="Invalid split"): + IntentDataset(split="test") diff --git a/tests/unit/data/datasets/test_response_dataset.py b/tests/unit/data/datasets/test_response_dataset.py index d88524e27e..b7fdb37693 100644 --- a/tests/unit/data/datasets/test_response_dataset.py +++ b/tests/unit/data/datasets/test_response_dataset.py @@ -353,7 +353,8 @@ def test_dailyomni_dataset(): # check the content assert first_example["messages"][0]["role"] == "user" assert first_example["messages"][0]["content"][0]["type"] == "video" - assert first_example["messages"][0]["content"][1]["type"] == "text" + assert first_example["messages"][0]["content"][1]["type"] == "audio" + assert first_example["messages"][0]["content"][2]["type"] == "text" assert first_example["messages"][1]["role"] == "assistant" assert first_example["messages"][1]["content"] == "B" diff --git a/tests/unit/models/generation/test_vllm_utils.py b/tests/unit/models/generation/test_vllm_utils.py index 8e21abbf46..6126cde38f 100644 --- a/tests/unit/models/generation/test_vllm_utils.py +++ b/tests/unit/models/generation/test_vllm_utils.py @@ -71,6 +71,51 @@ def test_vllm_utils_vlm_with_images_and_text(): assert prompts[1]["multi_modal_data"]["image"] == ["img2a", "img2b"] +def test_vllm_utils_vlm_with_audio_and_video_intent_path(): + """IntentTrain/IntentBench rollouts must surface both modalities to vLLM. + + Asserts ``multi_modal_data`` contains a ``video`` key built from + ``vllm_videos`` AND an ``audio`` key built from ``vllm_audios`` for the + same prompt. This is the regression bar for AC-3 of the audio+video + intent recipe; if either key is dropped at this site, vLLM rolls out a + text-only / single-modality prompt and the smoke run silently degrades. + """ + input_ids, input_lengths = _mk_inputs() + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "vllm_content": ["user: q1", "user: q2"], + "vllm_videos": [["frames-1"], ["frames-2"]], + "vllm_audios": [[("audio-1", 16000)], [("audio-2", 16000)]], + "task_name": ["intent-train", "intent-bench"], + } + ) + + prompts = format_prompt_for_vllm_generation(data) + assert len(prompts) == 2 + for i, prompt in enumerate(prompts): + assert "multi_modal_data" in prompt, ( + f"prompt {i} missing multi_modal_data: keys={list(prompt)}" + ) + mm = prompt["multi_modal_data"] + assert "video" in mm, ( + f"prompt {i} dropped vllm_videos -> multi_modal_data['video']: " + f"keys={list(mm)}" + ) + assert "audio" in mm, ( + f"prompt {i} dropped vllm_audios -> multi_modal_data['audio']: " + f"keys={list(mm)}" + ) + # The independent-streams path explicitly does not set + # mm_processor_kwargs={"use_audio_in_video": True} (Round 1 BitLesson + # BL-20260428-omni-use-audio-in-video). If a future change re-introduces + # that flag this assertion will need to be updated together with vLLM + # acceptance evidence. + for prompt in prompts: + assert "mm_processor_kwargs" not in prompt + + def test_vllm_utils_vlm_with_missing_images_fallback_to_tokens(): input_ids, input_lengths = _mk_inputs() # images None triggers fallback