diff --git a/.cursor/rules/avoid-heap-allocations.mdc b/.cursor/rules/avoid-heap-allocations.mdc
new file mode 100644
index 0000000..2ce523e
--- /dev/null
+++ b/.cursor/rules/avoid-heap-allocations.mdc
@@ -0,0 +1,87 @@
+---
+description: Prefer no heap allocations; use references, generics, and slices where possible
+globs: *.rs
+alwaysApply: true
+---
+This rule encourages avoiding heap allocations and temporary copies, especially in hot paths and library code.
+
+## General Principle
+
+Prefer stack allocation, references, and fixed-size or slice-based APIs over heap allocation. Avoid allocating or cloning when a reference, generic, or slice would suffice.
+
+## Guidelines
+
+### Prefer references over `clone()`
+
+- Do not call `.clone()` (or `.to_owned()`, `.to_string()`) just to satisfy a type when a reference (`&T`, `&str`) would work.
+- Pass `&str` or `&[T]` through call chains instead of cloning into `String` or `Vec` unless ownership is required.
+- Use `Cow<'_, T>` when a function might need either borrowed or owned data.
+
+**Good:**
+```rust
+fn format_names(infos: &[StackInfo]) -> String {
+    let mut s = String::new();
+    for (i, info) in infos.iter().enumerate() {
+        if i > 0 { s.push_str(", "); }
+        s.push_str(&info.type_name); // or take &str in API
+    }
+    s
+}
+```
+
+**Bad:**
+```rust
+let type_names: Vec<String> = infos.iter().map(|i| i.type_name.clone()).collect();
+format(type_names.join(", "))
+```
+
+### Prefer generics or function pointers over `Box<dyn Trait>`
+
+- Use generic parameters (`fn f<T>(t: T)`) or function pointers (`fn(fn(A) -> B)`) when the set of types or functions is known at compile time and you do not need runtime extensibility.
+- Reserve `Box<dyn Trait>` (or `&dyn Trait`) for cases where you truly need type erasure or a dynamic set of implementations.
+
+**Good:**
+```rust
+pub type ScopeFn = Box<dyn Fn(&mut DynSegment) -> Result<()> + Send + Sync>;
+// Only when you need to store heterogeneous closures. Prefer:
+fn with_callback<F: Fn(&mut DynSegment) -> Result<()>>(f: F) { ... }
+```
+
+**Prefer when possible:**
+```rust
+fn apply<F>(f: F) where F: Fn(u32) -> u32 { ... }
+// or
+type OpFn = fn(&mut DynSegment) -> Result<()>;
+```
+
+### Prefer slices over `Vec` for read-only or temporary views
+
+- Take or return `&[T]` (or `&mut [T]` when modifying in place) instead of `Vec` when the caller does not need ownership.
+- Avoid allocating a `Vec` only to pass a slice: e.g. use a block that borrows from the source so the borrow ends before the next use (`let ok = { let s = x.peek(); s.len() == n };`).
+- Use `slice.iter()` and work with references instead of collecting into a new `Vec` for matching or inspection.
+
+**Good:**
+```rust
+let matches = {
+    let top = segment.peek_stack_infos(num_operands);
+    top.len() == 2 && top[0].type_id == expected
+};
+if matches { segment.apply_op()?; }
+```
+
+**Bad:**
+```rust
+let type_ids: Vec<TypeId> = segment.peek_stack_infos(n).iter().map(|i| i.type_id).collect();
+if type_ids.len() == 2 && type_ids[0] == expected { ... }
+```
+
+### Avoid unnecessary temporaries
+
+- Do not allocate a `Vec` or `String` just to build a single message or slice when a loop or iterator over the source can build the result directly (e.g. one `String` or no intermediate collection).
+- Prefer `impl Iterator` or slice returns over returning a new `Vec` when the underlying data is already stored elsewhere.
+
+## Exceptions
+
+- Use `Vec` when you need an owned, growable sequence or when an API requires ownership.
+- Use `Box<dyn Trait>` when you need type erasure, dynamic dispatch, or to store heterogeneous types in a collection.
+- Use `clone()` when you genuinely need an independent copy (e.g. to pass across thread boundaries or to store in a structure that outlives the source).
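The `Cow<'_, T>` guideline above has no accompanying example; here is a minimal sketch of the borrow-or-allocate pattern it describes (the `normalize` function and its space-to-underscore rule are illustrative, not taken from this codebase):

```rust
use std::borrow::Cow;

/// Returns the input unchanged (borrowed) when no rewrite is needed,
/// allocating a new `String` only when a change is actually required.
/// (Illustrative example; not part of this codebase.)
fn normalize(s: &str) -> Cow<'_, str> {
    if s.contains(' ') {
        Cow::Owned(s.replace(' ', "_"))
    } else {
        Cow::Borrowed(s)
    }
}

fn main() {
    // No spaces: the input is borrowed, no allocation occurs.
    assert!(matches!(normalize("no_spaces"), Cow::Borrowed(_)));
    // Spaces present: an owned, rewritten copy is produced.
    assert_eq!(normalize("has spaces"), "has_spaces");
}
```

Callers that only read the result pay for an allocation solely on the path that needs one, which is the point of preferring `Cow` over returning `String` unconditionally.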
diff --git a/.cursor/rules/doc-comments.mdc b/.cursor/rules/doc-comments.mdc
index 442905c..4c4c20a 100644
--- a/.cursor/rules/doc-comments.mdc
+++ b/.cursor/rules/doc-comments.mdc
@@ -1,21 +1,154 @@
 ---
-description:
+description: Standards for Rust documentation comments
 globs: *.rs
-alwaysApply: false
+alwaysApply: true
 ---
-This rule provides a standard for Rust doc comments.
+This rule establishes standards for Rust documentation comments across all code.
 
-Doc comments should be formatted with line comments in accordance with the [Rust Style
-Guide](https://doc.rust-lang.org/stable/style-guide/#doc-comments).
+## General Principles
 
-All public components should have documentation following the conventions defined in [The RustDoc
-Book](https://doc.rust-lang.org/rustdoc/how-to-write-documentation.html).
+Doc comments should be formatted with line comments (`///` or `//!`) in accordance with the [Rust Style Guide](https://doc.rust-lang.org/stable/style-guide/#doc-comments).
 
-Documentation for functions should always start with a short summary that describes what the
-function does. Additional exposition is usually not necessary unless the function has preconditions
-that may lead to an erroneous result or panic.
+All public components **must** have documentation following the conventions defined in [The RustDoc Book](https://doc.rust-lang.org/rustdoc/how-to-write-documentation.html); documentation for private components is strongly encouraged (see Special Cases).
 
-For functions that recognize a grammar production, the production is a sufficient comment.
+## Structure
 
-Module comments should contain more exposition and examples and serve as a tutorial for how to use
-the components in the module.
+
+### Summary Sentence
+- **Always** start with a single-line summary that describes **what** the item does
+- Use present tense ("Returns", "Creates", "Calculates" not "Will return", "Will create")
+- Be concise and direct
+- End with a period
+
+### Additional Sections (when needed)
+These sections are only needed if they are not implied by the summary sentence.
+- Add a blank line after the summary before additional details
+- Use `# Arguments` for parameter descriptions
+- Use `# Returns` for complex return value explanations
+- Use `# Errors` to document error conditions
+- Use `# Panics` to document panic conditions
+- Use `# Safety` for unsafe functions
+
+## Specific Guidelines
+
+### Functions and Methods
+- Start with what the function **does** (not what it "will do")
+- Document preconditions that may cause errors or panics
+- Follow recommendations from [Better Code: Contracts](https://github.com/stlab/better-code/blob/main/better-code/src/chapter-2-contracts.md) adapted to Rust conventions
+- For grammar production parsers, the production itself is sufficient documentation
+
+**Good:**
+```rust
+/// `additive_expression = multiplicative_expression { ("+" | "-") multiplicative_expression }.`
+fn is_additive_expression(&mut self) -> Result<bool>
+```
+
+**Bad:**
+```rust
+/// This function will parse additive expressions
+fn is_additive_expression(&mut self) -> Result<bool>
+```
+
+### Structs and Enums
+- Describe the purpose and role of the type
+- State any invariants
+
+**Good:**
+```rust
+/// A scope-based operation lookup with stack support.
+///
+/// Provides a stack of scopes for operation resolution, with built-in operations
+/// as the fallback. Scopes are searched in LIFO order (most recently pushed first).
+pub struct OpLookup { /* ... */ }
+```
+
+### Traits
+- Describe what types implementing this trait represent
+- Document trait semantics and contracts
+- Provide examples of implementation
+
+### Modules
+- Use `//!` for module-level documentation
+- Provide a comprehensive overview with context
+- Include examples demonstrating typical usage patterns
+- Serve as a tutorial for the module's components
+
+**Good:**
+```rust
+//! Operation table for dynamically dispatching operations based on type signatures.
+//!
+//! This module provides a scope-based registry for operations that can be looked up
+//! based on an operation name (string) and the types of the operands.
+//!
+//! # Examples
+//! ```
+//! // Show typical usage
+//! ```
+```
+
+### Type Aliases
+- Explain what the alias represents
+- Clarify why the alias exists (readability, semantics)
+
+### Constants
+- Describe what the constant represents
+- Include units or context if applicable
+
+## Error Documentation
+
+When functions can return errors:
+- List specific error conditions in an `# Errors` section
+- Be explicit about **when** errors occur
+
+**Good:**
+```rust
+/// Looks up and applies an operation to the segment.
+///
+/// # Errors
+///
+/// Returns an error if no scope or built-in operation can handle the request.
+pub fn lookup(&self, name: &str, types: &[TypeId], segment: &mut DynSegment) -> Result<()>
+```
+
+## Panic Documentation
+
+When functions can panic:
+- Document **all** panic conditions in a `# Panics` section
+- Be specific about what causes the panic
+
+**Good:**
+```rust
+/// Returns the TypeId for this signature.
+///
+/// # Panics
+///
+/// Panics if the type_id_index is out of bounds (should never happen for valid signatures).
+fn type_id(&self) -> TypeId
+```
+
+## Examples
+
+Include examples (`# Examples`) for:
+- Public APIs
+- Non-obvious usage patterns
+- Types with specific initialization requirements
+- Functions with multiple valid usage patterns
+
+Use triple backticks with `rust` for proper syntax highlighting.
+
+## Conciseness
+
+- Avoid redundant information (don't restate the obvious from signatures)
+- Focus on **why** and **when**, not just **what**
+- Keep it brief but complete
+
+## Cross-References
+
+- Use backticks for code elements: `TypeId`, `DynSegment`, `OpLookup`
+- Use `[Type]` for links to other documented items
+- Link to related functions/types when helpful
+
+## Special Cases
+
+- **Getters/Setters:** Brief description is sufficient ("Returns the X" / "Sets the X")
+- **Grammar Productions:** The production itself is sufficient for parser functions
+- **Internal/Private Items:** Documentation encouraged but not required; focus on public API clarity
diff --git a/.vscode/settings.json b/.vscode/settings.json
index bb0edda..f9a6335 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -7,6 +7,7 @@
     "rust-analyzer.testExplorer": true,
     "testing.automaticallyOpenPeekView": "failureInVisibleDocument",
     "cSpell.words": [
-        "concatenative"
+        "concatenative",
+        "Peekable"
     ]
-}
+}
\ No newline at end of file
diff --git a/.vscode/tasks.json b/.vscode/tasks.json
index 19a5e48..728178e 100644
--- a/.vscode/tasks.json
+++ b/.vscode/tasks.json
@@ -1,17 +1,355 @@
-{
-    "version": "2.0.0",
-    "tasks": [
-        {
-            "type": "shell",
-            "label": "cargo test build",
-            "command": "cargo",
-            "args": [
-                "test",
-                "--no-run"
-            ],
-            "problemMatcher": [
-                "$rustc"
-            ]
-        }
-    ]
-}
+{
+    "version": "2.0.0",
+    "tasks": [
+        // ============ Build Tasks ============
+        {
+            "label": "cargo build",
+            "type": "cargo",
+            "command": "build",
+            "problemMatcher": [
+                "$rustc"
+            ],
+            "group": {
+                "kind": "build",
+                "isDefault": true
+            }
+        },
+        {
+            "label": "cargo test",
+            "type": "cargo",
"command": "test", + "problemMatcher": [ + "$rustc" + ], + "group": { + "kind": "test", + "isDefault": true + } + }, + { + "label": "cargo test build (no run)", + "type": "cargo", + "command": "test", + "args": [ + "--no-run" + ], + "problemMatcher": [ + "$rustc" + ] + }, + { + "label": "cargo check", + "type": "cargo", + "command": "check", + "problemMatcher": [ + "$rustc" + ] + }, + // ============ Documentation Tasks ============ + { + "label": "cargo doc (open)", + "type": "cargo", + "command": "doc", + "args": [ + "--lib", + "--no-deps", + "--open", + "--workspace" + ], + "problemMatcher": [ + "$rustc" + ] + }, + { + "label": "cargo doc", + "type": "cargo", + "command": "doc", + "args": [ + "--lib", + "--no-deps", + "--workspace" + ], + "problemMatcher": [ + "$rustc" + ] + }, + { + "label": "cargo test --doc", + "type": "cargo", + "command": "test", + "args": [ + "--doc", + "--workspace" + ], + "problemMatcher": [ + "$rustc" + ] + }, + // ============ Clippy Tasks ============ + { + "label": "cargo clippy", + "type": "cargo", + "command": "clippy", + "args": [ + "--workspace" + ], + "problemMatcher": [ + "$rustc" + ] + }, + { + "label": "cargo clippy --fix", + "type": "cargo", + "command": "clippy", + "args": [ + "--fix", + "--workspace" + ], + "problemMatcher": [ + "$rustc" + ] + }, + // ============ Package-Specific Tests ============ + { + "label": "cargo test (parser)", + "type": "cargo", + "command": "test", + "args": [ + "--package", + "cel-parser" + ], + "problemMatcher": [ + "$rustc" + ] + }, + { + "label": "cargo test (runtime)", + "type": "cargo", + "command": "test", + "args": [ + "--package", + "cel-runtime" + ], + "problemMatcher": [ + "$rustc" + ] + }, + // ============ Sanitizers - macOS/Linux ============ + { + "label": "sanitizer: address (macOS)", + "type": "shell", + "command": "cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-apple-darwin", + "--workspace" + ], + "options": { + "env": { + 
"RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=address", + "RUSTDOCFLAGS": "-Zsanitizer=address" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + }, + { + "label": "sanitizer: address lib only (macOS)", + "type": "shell", + "command": "cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-apple-darwin", + "--lib", + "--workspace" + ], + "options": { + "env": { + "RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=address" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + }, + { + "label": "sanitizer: leak (macOS)", + "type": "shell", + "command": "cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-apple-darwin", + "--workspace" + ], + "options": { + "env": { + "RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=leak", + "RUSTDOCFLAGS": "-Zsanitizer=leak" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + }, + { + "label": "sanitizer: thread (macOS)", + "type": "shell", + "command": "cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-apple-darwin", + "--workspace" + ], + "options": { + "env": { + "RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=thread", + "RUSTDOCFLAGS": "-Zsanitizer=thread" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + }, + { + "label": "sanitizer: address (Linux/WSL2)", + "type": "shell", + "command": "${env:HOME}/.cargo/bin/cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-unknown-linux-gnu", + "--workspace" + ], + "options": { + "env": { + "RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=address", + "RUSTDOCFLAGS": "-Zsanitizer=address" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": 
"dedicated" + } + }, + { + "label": "sanitizer: address lib only (Linux/WSL2)", + "type": "shell", + "command": "${env:HOME}/.cargo/bin/cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-unknown-linux-gnu", + "--lib", + "--workspace" + ], + "options": { + "env": { + "RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=address" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + }, + { + "label": "sanitizer: leak (Linux/WSL2)", + "type": "shell", + "command": "${env:HOME}/.cargo/bin/cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-unknown-linux-gnu", + "--workspace" + ], + "options": { + "env": { + "RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=leak", + "RUSTDOCFLAGS": "-Zsanitizer=leak" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + }, + { + "label": "sanitizer: thread (Linux/WSL2)", + "type": "shell", + "command": "${env:HOME}/.cargo/bin/cargo", + "args": [ + "+nightly", + "test", + "-Zbuild-std", + "--target", + "x86_64-unknown-linux-gnu", + "--workspace" + ], + "options": { + "env": { + "RUST_BACKTRACE": "1", + "RUSTFLAGS": "-Zsanitizer=thread", + "RUSTDOCFLAGS": "-Zsanitizer=thread" + } + }, + "problemMatcher": [ + "$rustc" + ], + "presentation": { + "reveal": "always", + "panel": "dedicated" + } + } + ] +} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 5ad27b9..8913fe1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,21 +8,10 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" -[[package]] -name = "cel-parser" -version = "0.1.0" -dependencies = [ - "litrs", - "owo-colors", - "proc-macro2", - "quote", -] - [[package]] name = "cel-rs" version = "0.1.0" dependencies = [ - "cel-parser", "cel-rs-macros", "cel-runtime", 
"proc-macro2", @@ -32,7 +21,7 @@ dependencies = [ name = "cel-rs-macros" version = "0.1.0" dependencies = [ - "cel-parser", + "cel-runtime", "proc-macro2", "quote", ] @@ -42,7 +31,12 @@ name = "cel-runtime" version = "0.1.0" dependencies = [ "anyhow", - "cel-rs-macros", + "once_cell", + "owo-colors", + "phf", + "proc-macro2", + "quote", + "syn", "typenum", ] @@ -76,13 +70,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] -name = "litrs" -version = "0.4.2" +name = "once_cell" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5e54036fe321fd421e10d732f155734c4e4afd610dd556d9a82833ab3ee0bed" -dependencies = [ - "proc-macro2", -] +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "owo-colors" @@ -94,6 +85,48 @@ dependencies = [ "supports-color 3.0.2", ] +[[package]] +name = "phf" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" +dependencies = [ + "phf_macros", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_macros" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" +dependencies = [ + "phf_generator", + "phf_shared", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher", +] + 
[[package]] name = "proc-macro2" version = "1.0.95" @@ -112,6 +145,27 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "supports-color" version = "2.1.0" @@ -131,6 +185,17 @@ dependencies = [ "is_ci", ] +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "typenum" version = "1.18.0" diff --git a/Cargo.toml b/Cargo.toml index d95636f..1b5d984 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,15 +6,13 @@ description = "A stack-based runtime for developing domain specific languages" [dependencies] cel-runtime = { path = "./cel-runtime" } -cel-parser = { path = "./cel-parser" } cel-rs-macros = { path = "./cel-rs-macros" } proc-macro2 = "1.0" [workspace] members = [ "cel-runtime", - "cel-rs-macros", - "cel-parser" + "cel-rs-macros" ] [lib] diff --git a/cel-parser/Cargo.toml b/cel-parser/Cargo.toml deleted file mode 100644 index 2b75e59..0000000 --- a/cel-parser/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "cel-parser" -version = "0.1.0" -edition = "2024" -description = "Parser for Common Expression Language (CEL) expressions" -license = "MIT" - -[dependencies] -proc-macro2 = { version = "1.0", features = 
["span-locations"] } -quote = "1.0" -owo-colors = { version = "4.2", features = ["supports-colors"] } -litrs = "0.4.1" diff --git a/cel-parser/src/lib.rs b/cel-parser/src/lib.rs deleted file mode 100644 index bb857df..0000000 --- a/cel-parser/src/lib.rs +++ /dev/null @@ -1,768 +0,0 @@ -#![warn(missing_docs)] - -//! A recursive descent parser for CEL (Common Expression Language) expressions. -//! -//! This crate provides a parser that can parse CEL expressions into token streams -//! suitable for use in procedural macros. The parser follows the CEL grammar -//! specification and provides detailed error reporting with source location information. -//! -//! # Grammar -//! -//! ```text -//! expression = or_expression . -//! or_expression = and_expression { "||" and_expression }. -//! and_expression = comparison_expression { "&&" comparison_expression }. -//! comparison_expression = bitwise_or_expression [ ("==" | "!=" | "<" | ">" | "<=" | ">=") bitwise_or_expression ]. -//! bitwise_or_expression = bitwise_xor_expression { "|" bitwise_xor_expression }. -//! bitwise_xor_expression = bitwise_and_expression { "^" bitwise_and_expression }. -//! bitwise_and_expression = bitwise_shift_expression { "&" bitwise_shift_expression }. -//! bitwise_shift_expression = additive_expression { ("<<" | ">>") additive_expression }. -//! additive_expression = multiplicative_expression { ("+" | "-") multiplicative_expression }. -//! multiplicative_expression = unary_expression { ("*" | "/" | "%") unary_expression }. -//! unary_expression = (("-" | "!") unary_expression) | primary_expression. -//! primary_expression = literal | identifier | "(" expression ")". -//! ``` -//! -//! # Examples -//! -//! ## Basic Usage -//! -//! ```rust -//! use cel_parser::CELParser; -//! use proc_macro2::TokenStream; -//! use std::str::FromStr; -//! -//! let input = TokenStream::from_str("10 + 20").unwrap(); -//! let mut parser = CELParser::new(input.into_iter()); -//! assert!(parser.is_expression()); -//! 
``` -//! -//! ## Error Formatting -//! -//! ```rust -//! use cel_parser::CELParser; -//! use proc_macro2::TokenStream; -//! use std::str::FromStr; -//! -//! let line = line!() + 1; -//! let source = r#" -//! 10 + 20 30 -//! "#; // Invalid: missing operator -//! let input = TokenStream::from_str(source).unwrap(); -//! let mut parser = CELParser::new(input.into_iter()); -//! -//! if !parser.is_expression() { -//! // Format error starting at line 1 -//! if let Some(formatted_error) = parser.format_error(source, file!(), line) { -//! println!("{}", formatted_error); -//! // Output: -//! // error: Unexpected token -//! // --> example.cel:1:8 -//! // | -//! // 1 | 10 + 20 30 -//! // | ^^ -//! } -//! } -//! ``` - -use litrs::StringLit; -use owo_colors::OwoColorize; -use proc_macro2::{Delimiter, Spacing, Span, TokenStream, TokenTree}; -use quote::quote_spanned; -use std::iter::Peekable; - -/// A recursive descent parser for expressions. -/// -/// Grammar: -/// ```text -/// expression = or_expression . -/// or_expression = and_expression { "||" and_expression }. -/// and_expression = comparison_expression { "&&" comparison_expression }. -/// comparison_expression = bitwise_or_expression [ ("==" | "!=" | "<" | ">" | "<=" | ">=") bitwise_or_expression ]. -/// bitwise_or_expression = bitwise_xor_expression { "|" bitwise_xor_expression }. -/// bitwise_xor_expression = bitwise_and_expression { "^" bitwise_and_expression }. -/// bitwise_and_expression = bitwise_shift_expression { "&" bitwise_shift_expression }. -/// bitwise_shift_expression = additive_expression { ("<<" | ">>") additive_expression }. -/// additive_expression = multiplicative_expression { ("+" | "-") multiplicative_expression }. -/// multiplicative_expression = unary_expression { ("*" | "/" | "%") unary_expression }. -/// unary_expression = (("-" | "!") unary_expression) | primary_expression. -/// primary_expression = literal | identifier | "(" expression ")". 
-/// ``` -/// -/// # Examples -/// -/// ## Basic Usage -/// -/// ```rust -/// use cel_parser::CELParser; -/// use proc_macro2::TokenStream; -/// use std::str::FromStr; -/// -/// let input = TokenStream::from_str("10 + 20").unwrap(); -/// let mut parser = CELParser::new(input.into_iter()); -/// assert!(parser.is_expression()); -/// ``` -/// -/// ## Error Formatting -/// -/// ```rust -/// use cel_parser::CELParser; -/// use proc_macro2::TokenStream; -/// use std::str::FromStr; -/// -/// let line = line!() + 1; -/// let source = r#" -/// 10 + 20 30 -/// "#; // Invalid: missing operator -/// let input = TokenStream::from_str(source).unwrap(); -/// let mut parser = CELParser::new(input.into_iter()); -/// -/// if !parser.is_expression() { -/// // Format error starting at line 1 -/// if let Some(formatted_error) = parser.format_error(source, file!(), line) { -/// println!("{}", formatted_error); -/// // Output: -/// // error: Unexpected token -/// // --> example.cel:1:8 -/// // | -/// // 1 | 10 + 20 30 -/// // | ^^ -/// } -/// } -/// ``` -pub struct CELParser> { - tokens: Peekable, - output: TokenStream, -} - -impl + Clone> CELParser { - /// Extracts the error message from the parser's output token stream. - /// - /// This method searches for a `compile_error!` macro call in the output - /// and extracts the string literal argument as the error message. - /// - /// # Returns - /// - /// Returns `Some(message)` if an error message was found, or `None` if - /// no error was present in the output. - pub fn extract_error_message(&self) -> Option { - let mut tokens = self.output.clone().into_iter(); - - while let Some(token) = tokens.next() { - if let TokenTree::Ident(ident) = token - && ident == "compile_error" - && let Some(TokenTree::Punct(punct)) = tokens.next() - && punct.as_char() == '!' 
- && let Some(TokenTree::Group(group)) = tokens.next() - && group.delimiter() == Delimiter::Parenthesis - { - let mut group_tokens = group.stream().into_iter(); - if let Some(TokenTree::Literal(lit)) = group_tokens.next() { - // Clean extraction using litrs - if let Ok(string_lit) = StringLit::try_from(lit) { - return Some(string_lit.value().to_string()); - } - } - } - } - None - } - - /// - pub fn format_error( - &self, - source_code: &str, - filename: &str, - start_line: u32, - ) -> Option { - if let Some(error_msg) = self.extract_error_message() - && let Some(span) = self.get_error_span() - { - return Some(self.format_rustc_style( - &error_msg, - span, - source_code, - filename, - start_line, - )); - } - - None - } - - fn format_rustc_style( - &self, - message: &str, - span: Span, - source: &str, - filename: &str, - start_line: u32, - ) -> String { - let start = span.start(); - let end = span.end(); - - let lines: Vec<&str> = source.lines().collect(); - - let mut output = String::new(); - - // Calculate offset line numbers (start_line is 1-based) - let error_line = start_line + (start.line as u32) - 1; - let error_column = start.column + 1; // +1 because the column is 0-based but the error is 1-based - - // Calculate the width needed for line numbers - // end.line is the last line within the source span (1-based) - // start_line is the offset to get actual file line numbers - // The maximum displayed line number will be: start_line + end.line - 1 - let max_line_num = start_line + (end.line as u32) - 1; - let line_width = max_line_num.to_string().len(); - - // Error header with red and bold "error:" - output.push_str(&format!("{}: {}\n", "error".red().bold(), message)); - output.push_str(&format!( - " {} {}:{}:{}\n", - "-->".blue().bold(), - filename.blue(), - error_line.to_string().blue(), - error_column.to_string().blue() - )); - output.push_str(&format!( - "{:width$} {}\n", - "", - "|".blue().bold(), - width = line_width - )); - - // Show the problematic 
line(s) - for line_num in start.line..=end.line { - if let Some(line_content) = lines.get(line_num.saturating_sub(1)) { - let display_line_num = start_line + (line_num as u32) - 1; - output.push_str(&format!( - "{} {} {}\n", - display_line_num.to_string().blue().bold(), - "|".blue().bold(), - line_content - )); - - // Add caret indicators - if line_num == start.line { - output.push_str(&format!( - "{:width$} {} ", - "", - "|".blue().bold(), - width = line_width - )); - - // Add spaces up to start column - output.push_str(&" ".repeat(start.column)); - - // Add carets in red - let caret_len = if start.line == end.line { - end.column.saturating_sub(start.column).max(1) - } else { - line_content - .len() - .saturating_sub(start.column.saturating_sub(1)) - }; - - output.push_str(&"^".repeat(caret_len).red().bold().to_string()); - output.push('\n'); - } - } - } - - output - } - - fn get_error_span(&self) -> Option { - // The compile_error! TokenStream structure is: - // TokenTree::Ident("compile_error") - with the span we want - // TokenTree::Punct('!') - // TokenTree::Group(...) - containing the message, also with the span - - let mut tokens = self.output.clone().into_iter(); - - // Look for the first token (should be "compile_error" ident) - if let Some(first_token) = tokens.next() { - match first_token { - TokenTree::Ident(ident) if ident == "compile_error" => { - return Some(ident.span()); - } - _ => { - // Fallback: try to get span from any token in the stream - return Some(first_token.span()); - } - } - } - None - } - - /// Creates a new CEL parser with the given token iterator. - /// - /// # Arguments - /// - /// * `tokens` - An iterator over `TokenTree` items to parse - /// - /// # Returns - /// - /// A new `CELParser` instance ready to parse the tokens. - pub fn new(tokens: I) -> Self { - let output = TokenStream::new(); - CELParser { - tokens: tokens.peekable(), - output, - } - } - - /// Returns a reference to the parser's output token stream. 
- /// - /// This contains the parsed tokens or error information if parsing failed. - /// - /// # Returns - /// - /// A reference to the output `TokenStream`. - pub fn get_output(&self) -> &TokenStream { - &self.output - } - - fn advance(&mut self) { - self.tokens.next(); - } - - /// Reports a parsing error by adding a `compile_error!` macro to the output. - /// - /// This method creates a compile-time error with the given message at the - /// current token's span location. - /// - /// # Arguments - /// - /// * `message` - The error message to report - /// - /// # Returns - /// - /// Always returns `false` to indicate parsing failure. - pub fn report_error(&mut self, message: &str) -> bool { - let span = self - .tokens - .peek() - .map_or_else(proc_macro2::Span::call_site, |token| token.span()); - self.output = quote_spanned!(span => compile_error!(#message)); - false - } - - fn is_one_of_punc(token: Option<&TokenTree>, sequence: &[char]) -> bool { - match token { - Some(TokenTree::Punct(punct)) => sequence.contains(&punct.as_char()), - _ => false, - } - } - - fn is_punctuation(&mut self, string: &str) -> bool { - let mut tmp = self.tokens.clone(); - let mut spacing = Spacing::Joint; - for c in string.chars() { - if spacing == Spacing::Alone { - return false; - } - match tmp.peek() { - Some(TokenTree::Punct(punct)) => { - if punct.as_char() != c { - return false; - } - spacing = punct.spacing(); - tmp.next(); - } - _ => return false, - } - } - // filter false positives for compound operators - if spacing == Spacing::Joint && string.len() == 1 { - let compound_chars = [ - ('&', &['&'][..]), - ('|', &['|'][..]), - ('<', &['<', '='][..]), - ('>', &['>', '='][..]), - ]; - let c = string.chars().next().unwrap(); // safe since string.len() == 1 - - if let Some((_, next_chars)) = compound_chars.iter().find(|(ch, _)| *ch == c) - && Self::is_one_of_punc(tmp.peek(), next_chars) - { - return false; - } - } - self.tokens = tmp; - true - } - - fn is_one_of_punctuation(&mut 
self, sequence: &[&str]) -> bool { - for s in sequence { - if self.is_punctuation(s) { - return true; - } - } - false - } - - /// `expression = or_expression .` - pub fn is_expression(&mut self) -> bool { - if !self.is_or_expression() { - return false; - } - if self.tokens.peek().is_some() { - return self.report_error("unexpected token"); - } - true - } - - /// `or_expression = and_expression { "||" and_expression }.` - fn is_or_expression(&mut self) -> bool { - if self.is_and_expression() { - while self.is_one_of_punctuation(&["||"]) { - if !self.is_and_expression() { - return self.report_error("expected and_expression"); - } - } - true - } else { - false - } - } - - /// `and_expression = comparison_expression { "&&" comparison_expression }.` - fn is_and_expression(&mut self) -> bool { - if self.is_comparison_expression() { - while self.is_one_of_punctuation(&["&&"]) { - if !self.is_comparison_expression() { - return self.report_error("expected comparison_expression"); - } - } - true - } else { - false - } - } - - /// `comparison_expression = bitwise_or_expression [ ("==" | "!=" | "<" | ">" | "<=" | ">=") bitwise_or_expression ].` - fn is_comparison_expression(&mut self) -> bool { - if self.is_bitwise_or_expression() { - if self.is_one_of_punctuation(&["==", "!=", "<", ">", "<=", ">="]) - && !self.is_bitwise_or_expression() - { - return self.report_error("expected bitwise_or_expression"); - } - true - } else { - false - } - } - - /// `bitwise_or_expression = bitwise_xor_expression { "|" bitwise_xor_expression }.` - fn is_bitwise_or_expression(&mut self) -> bool { - if self.is_bitwise_xor_expression() { - while self.is_one_of_punctuation(&["|"]) { - if !self.is_bitwise_xor_expression() { - return self.report_error("expected bitwise_xor_expression"); - } - } - true - } else { - false - } - } - - /// `bitwise_xor_expression = bitwise_and_expression { "^" bitwise_and_expression }.` - fn is_bitwise_xor_expression(&mut self) -> bool { - if 
self.is_bitwise_and_expression() { - while self.is_one_of_punctuation(&["^"]) { - if !self.is_bitwise_and_expression() { - return self.report_error("expected bitwise_and_expression"); - } - } - true - } else { - false - } - } - - /// `bitwise_and_expression = bitwise_shift_expression { "&" bitwise_shift_expression }.` - fn is_bitwise_and_expression(&mut self) -> bool { - if self.is_bitwise_shift_expression() { - while self.is_one_of_punctuation(&["&"]) { - if !self.is_bitwise_shift_expression() { - return self.report_error("expected bitwise_shift_expression"); - } - } - true - } else { - false - } - } - - /// `bitwise_shift_expression = additive_expression { ("<<" | ">>") additive_expression }.` - fn is_bitwise_shift_expression(&mut self) -> bool { - if self.is_additive_expression() { - while self.is_one_of_punctuation(&["<<", ">>"]) { - if !self.is_additive_expression() { - return self.report_error("expected additive_expression"); - } - } - true - } else { - false - } - } - - /// `additive_expression = multiplicative_expression { ("+" | "-") multiplicative_expression }.` - fn is_additive_expression(&mut self) -> bool { - if self.is_multiplicative_expression() { - while self.is_one_of_punctuation(&["+", "-"]) { - if !self.is_multiplicative_expression() { - return self.report_error("expected multiplicative_expression"); - } - } - true - } else { - false - } - } - - /// `multiplicative_expression = unary_expression { ("*" | "/" | "%") unary_expression }.` - fn is_multiplicative_expression(&mut self) -> bool { - if self.is_unary_expression() { - while self.is_one_of_punctuation(&["*", "/", "%"]) { - if !self.is_unary_expression() { - return self.report_error("expected unary_expression"); - } - } - true - } else { - false - } - } - - /// `unary_expression = (("-" | "!") unary_expression) | primary_expression.` - fn is_unary_expression(&mut self) -> bool { - if self.is_one_of_punctuation(&["-", "!"]) { - if !self.is_unary_expression() { - return 
self.report_error("expected unary_expression"); - } - true - } else { - self.is_primary_expression() - } - } - - /// `primary_expression = literal | identifier | "(" expression ")".` - fn is_primary_expression(&mut self) -> bool { - match self.tokens.peek() { - Some(TokenTree::Literal(_)) => { - self.advance(); - true - } - Some(TokenTree::Ident(_)) => { - self.advance(); - true - } - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Parenthesis => { - let mut parser = CELParser::new(group.stream().into_iter()); - if parser.is_expression() { - self.advance(); - true - } else { - false - } - } - _ => false, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use proc_macro2::TokenStream; - use std::str::FromStr; - - #[test] - fn simple_expression() { - let input = TokenStream::from_str("10").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn incomplete_expression() { - let input = TokenStream::from_str("10 + 25 25").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(!parser.is_expression()); - assert_eq!( - parser.output.to_string(), - "compile_error ! 
(\"unexpected token\")" - ); - } - - #[test] - fn arithmetic_expression() { - let input = TokenStream::from_str("10 + 20 * 30").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn parenthesized_expression() { - let input = TokenStream::from_str("(10 + 20) * 30").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn complex_expression() { - let input = TokenStream::from_str("10 + 20 * (30 - 5) / 2").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn logical_expression() { - let input = TokenStream::from_str("a && b || c").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn comparison_expression() { - let input = TokenStream::from_str("a == b && c > d").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn bitwise_expression() { - let input = TokenStream::from_str("a | b & c ^ d").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn shift_expression() { - let input = TokenStream::from_str("a << 2 + b >> 1").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn unary_expression() { - let input = TokenStream::from_str("-a + !b").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn chained_unary_expression() { - let input = TokenStream::from_str("!!a + --b").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(parser.is_expression()); - } - - #[test] - fn invalid_expression() { - let input = TokenStream::from_str("+").unwrap(); - let mut parser = CELParser::new(input.into_iter()); - assert!(!parser.is_expression()); - } - 
- /// Helper function to strip ANSI escape codes from a string for testing purposes - fn strip_ansi_codes(input: &str) -> String { - // Basic regex to remove ANSI escape sequences - // ANSI escape sequences start with ESC (0x1B) followed by '[' and end with a letter - let mut result = String::new(); - let mut chars = input.chars().peekable(); - - while let Some(ch) = chars.next() { - if ch == '\x1B' { - // Found ESC, check if it's followed by '[' - if chars.peek() == Some(&'[') { - chars.next(); // consume '[' - // Skip until we find a letter (which ends the escape sequence) - while let Some(ch) = chars.next() { - if ch.is_ascii_alphabetic() { - break; - } - } - } else { - result.push(ch); - } - } else { - result.push(ch); - } - } - - result - } - - #[test] - fn error_formatting() { - let source = "10 + 20 30"; // Missing operator between 20 and 30 - let input = TokenStream::from_str(source).unwrap(); - let mut parser = CELParser::new(input.into_iter()); - - // This should fail parsing - assert!(!parser.is_expression()); - - // Test error message extraction - let error_msg = parser.extract_error_message(); - assert!(error_msg.is_some()); - assert_eq!(error_msg.unwrap(), "unexpected token"); - - // Test error formatting - let formatted_error = parser.format_error(source, "test.cel", 1u32); - assert!(formatted_error.is_some()); - - // Strip ANSI codes for testing - let formatted = strip_ansi_codes(&formatted_error.unwrap()); - assert!(formatted.contains("error: unexpected token")); - assert!(formatted.contains("test.cel:1:")); // Should include line number - assert!(formatted.contains("1 | 10 + 20 30")); // Should show the line with line number - assert!(formatted.contains("^")); // Should have carets pointing to the error - } - - #[test] - fn error_formatting_with_line_offset() { - let source = "a + b c"; // Missing operator between b and c - let input = TokenStream::from_str(source).unwrap(); - let mut parser = CELParser::new(input.into_iter()); - - // This should 
fail parsing - assert!(!parser.is_expression()); - - // Test error formatting with line offset (as if expression starts at line 42) - let formatted_error = parser.format_error(source, "large_file.rs", 42u32); - assert!(formatted_error.is_some()); - - // Strip ANSI codes for testing - let formatted = strip_ansi_codes(&formatted_error.unwrap()); - assert!(formatted.contains("error: unexpected token")); - assert!(formatted.contains("large_file.rs:42:")); // Should show offset line number - assert!(formatted.contains("42 | a + b c")); // Should show the line with offset line number - assert!(formatted.contains("^")); // Should have carets pointing to the error - } - - #[test] - fn print_error_formatting() { - let line = line!() + 1; - let source = r#" - - 10 + 20 30 // Unexpected token - - "#; - - let input = TokenStream::from_str(source).unwrap(); - let mut parser = CELParser::new(input.into_iter()); - - if !parser.is_expression() { - if let Some(formatted_error) = parser.format_error(source, file!(), line) { - println!("{}", formatted_error); - } - } - } -} diff --git a/cel-rs-macros/Cargo.toml b/cel-rs-macros/Cargo.toml index 4d48ebb..a9b168f 100644 --- a/cel-rs-macros/Cargo.toml +++ b/cel-rs-macros/Cargo.toml @@ -11,4 +11,4 @@ proc-macro = true [dependencies] quote = "1.0" proc-macro2 = "1.0" -cel-parser = { path = "../cel-parser" } +cel-runtime = { path = "../cel-runtime" } diff --git a/cel-rs-macros/src/lib.rs b/cel-rs-macros/src/lib.rs index beb0a18..e1bbd94 100644 --- a/cel-rs-macros/src/lib.rs +++ b/cel-rs-macros/src/lib.rs @@ -30,9 +30,10 @@ //! }; //! ``` -use cel_parser::CELParser; +use cel_runtime::{CELError, CELParser, OpLookup}; use proc_macro::TokenStream as ProcMacroTokenStream; -use proc_macro2::TokenStream; +use proc_macro2::{Literal, TokenStream}; +use quote::quote_spanned; /// Validates that the input contains a valid CEL expression. 
/// @@ -45,11 +46,23 @@ use proc_macro2::TokenStream; #[proc_macro] pub fn expression(input: ProcMacroTokenStream) -> ProcMacroTokenStream { let input = TokenStream::from(input); - let mut parser = CELParser::new(input.into_iter()); - if !parser.is_expression() { - parser.report_error("Expected expression"); + let mut parser = CELParser::new(OpLookup::new()); + parser.set_tokens(input.into_iter()); + match parser.is_expression() { + Ok(true) => ProcMacroTokenStream::new(), + Ok(false) => { + let e = CELError::new( + "Expected expression", + cel_runtime::parser::SourceSpan::default(), + ); + let msg_lit = Literal::string(&e.to_string()); + quote_spanned!(proc_macro2::Span::call_site() => compile_error!(#msg_lit)).into() + } + Err(e) => { + let msg_lit = Literal::string(&e.to_string()); + quote_spanned!(proc_macro2::Span::call_site() => compile_error!(#msg_lit)).into() + } } - parser.get_output().clone().into() } /// Prints the tokens for debugging purposes. @@ -58,7 +71,7 @@ pub fn expression(input: ProcMacroTokenStream) -> ProcMacroTokenStream { /// ```rust /// use cel_rs_macros::print_tokens; /// print_tokens! 
{
-/// 10
+/// "hello"_key
 /// };
 /// ```
 #[proc_macro]
diff --git a/cel-runtime/Cargo.toml b/cel-runtime/Cargo.toml
index 41987ba..b464115 100644
--- a/cel-runtime/Cargo.toml
+++ b/cel-runtime/Cargo.toml
@@ -7,4 +7,9 @@ description = "A stack-based runtime for developing domain specific languages"
 [dependencies]
 anyhow = "1.0"
 typenum = "1.18.0"
-cel-rs-macros = { path = "../cel-rs-macros" }
+proc-macro2 = { version = "1.0", features = ["span-locations"] }
+quote = "1.0"
+owo-colors = { version = "4.2", features = ["supports-colors"] }
+syn = { version = "2.0", features = ["extra-traits", "parsing"] }
+phf = { version = "0.11", features = ["macros"] }
+once_cell = "1.19"
\ No newline at end of file
diff --git a/cel-runtime/src/dyn_segment.rs b/cel-runtime/src/dyn_segment.rs
index cea5c64..933a22d 100644
--- a/cel-runtime/src/dyn_segment.rs
+++ b/cel-runtime/src/dyn_segment.rs
@@ -6,17 +6,39 @@
 use crate::raw_stack::RawStack;
 use crate::{CStackListHeadLimit, CStackListHeadPadded, ReverseList};
 use anyhow::Result;
 use anyhow::ensure;
+use std::borrow::Cow;
 use std::any::TypeId;
 use std::cmp::max;
 
+/// Recursive type node carrying a [`TypeId`], display name, and optional associated types.
+///
+/// Used for function parameter/return types, tuple elements, and similar structure.
+/// Not yet used; reserved for parse-time call checking and richer error reporting.
+#[derive(Clone, Debug)]
+pub struct AssociatedType {
+    /// Runtime type id for this node.
+    pub type_id: TypeId,
+    /// Human-readable name for error reporting (borrowed when from `type_name::<T>()`).
+    pub type_name: Cow<'static, str>,
+    /// Child types (e.g. function parameters, tuple elements).
+    pub associated: Vec<AssociatedType>,
+}
+
 /// Information about a type on the stack, including its cleanup function.
 ///
-/// This struct holds metadata about a type that has been pushed onto the stack,
-/// including how to properly drop it when the stack is unwound.
+/// Holds metadata for a value pushed onto the stack: runtime type id, display name
+/// for errors, padding, dropper, and an optional list of associated types.
 pub struct StackInfo {
-    pub(crate) stack_id: TypeId,
-    stack_unwind: Dropper,
-    padded: bool,
+    /// Runtime type id for this stack slot (e.g. for scope matching).
+    pub type_id: TypeId,
+    /// Human-readable type name for error reporting (borrowed when from `type_name::<T>()`).
+    pub type_name: Cow<'static, str>,
+    /// Whether padding was inserted before this value for alignment.
+    pub(crate) padding: bool,
+    /// Dropper used when unwinding the stack.
+    dropper: Dropper,
+    /// Associated types (e.g. function params, tuple elements). Unused for now.
+    pub associated: Vec<AssociatedType>,
 }
 
 /// Trait for converting a type list into a list of stack information.
@@ -43,9 +65,11 @@ impl ToTypeIdList
     fn to_stack_info_list() -> Vec<StackInfo> {
         let mut list = T::to_stack_info_list();
         list.push(StackInfo {
-            stack_id: TypeId::of::<H>(),
-            stack_unwind: |stack| unsafe { stack.drop::<H>(Self::HEAD_PADDED) },
-            padded: Self::HEAD_PADDED,
+            type_id: TypeId::of::<H>(),
+            type_name: Cow::Borrowed(std::any::type_name::<H>()),
+            padding: Self::HEAD_PADDED,
+            dropper: |stack| unsafe { stack.drop::<H>(Self::HEAD_PADDED) },
+            associated: Vec::new(),
         });
         list
     }
@@ -86,6 +110,8 @@ type Dropper = fn(&mut RawStack);
 pub struct DynSegment {
     pub(crate) segment: RawSegment,
     pub(crate) argument_ids: Vec<TypeId>,
+    /// Type names for each argument slot, for error reporting (parallel to `argument_ids`).
+    pub(crate) argument_names: Vec<Cow<'static, str>>,
     pub(crate) stack_ids: Vec<StackInfo>,
     stack_index: usize,
 }
@@ -100,7 +126,8 @@ impl DynSegment {
         let stack_ids = ReverseList::<Args>::to_stack_info_list();
         DynSegment {
             segment: RawSegment::new(),
-            argument_ids: stack_ids.iter().map(|s| s.stack_id).collect(),
+            argument_ids: stack_ids.iter().map(|s| s.type_id).collect(),
+            argument_names: stack_ids.iter().map(|s| s.type_name.clone()).collect(),
             stack_ids,
             stack_index: size_of::<ReverseList<Args>>(),
         }
     }
@@ -112,8 +139,9 @@ impl DynSegment {
     pub fn new_fragment(&self) -> Self {
         DynSegment {
             segment: RawSegment::new(),
-            argument_ids: Vec::<TypeId>::new(), // should be optional?
-            stack_ids: Vec::<StackInfo>::new(),
+            argument_ids: Vec::new(),
+            argument_names: Vec::new(),
+            stack_ids: Vec::new(),
             stack_index: self.stack_index,
         }
     }
@@ -135,7 +163,7 @@ impl DynSegment {
         );
         let start = self.stack_ids.len() - L::LENGTH;
         ensure!(
-            TypeIdIterator::<L>::new().eq(self.stack_ids[start..].iter().map(|info| info.stack_id)),
+            TypeIdIterator::<L>::new().eq(self.stack_ids[start..].iter().map(|info| info.type_id)),
             "stack type ids do not match"
         );
         self.stack_ids.truncate(start);
@@ -151,13 +179,15 @@ impl DynSegment {
         let padded = aligned_index != self.stack_index;
 
         self.stack_ids.push(StackInfo {
-            stack_id: TypeId::of::<T>(),
-            stack_unwind: if padded {
+            type_id: TypeId::of::<T>(),
+            type_name: Cow::Borrowed(std::any::type_name::<T>()),
+            padding: padded,
+            dropper: if padded {
                 |stack| unsafe { stack.drop::<T>(true) }
             } else {
                 |stack| unsafe { stack.drop::<T>(false) }
             },
-            padded,
+            associated: Vec::new(),
         });
         self.stack_index = aligned_index + size_of::<T>();
     }
@@ -166,11 +196,46 @@ impl DynSegment {
         let mut result = [false; N];
         let start = self.stack_ids.len().saturating_sub(N);
         for (i, info) in self.stack_ids[start..].iter().enumerate() {
-            result[i] = info.padded;
+            result[i] = info.padding;
         }
         result
     }
 
+    /// Captures the current stack droppers for use when unwinding on error.
+    fn capture_unwind(&self) -> Vec<Dropper> {
+        self.stack_ids.iter().map(|info| info.dropper).collect()
+    }
+
+    /// Runs the captured droppers in reverse order on error, then propagates the error.
+    fn unwind_on_err<R>(
+        unwind: &[Dropper],
+        stack: &mut RawStack,
+        result: Result<R>,
+    ) -> Result<R> {
+        match result {
+            Ok(r) => Ok(r),
+            Err(e) => {
+                for dropper in unwind.iter().rev() {
+                    dropper(stack);
+                }
+                Err(e)
+            }
+        }
+    }
+
+    /// Returns a slice of the top N [`StackInfo`] entries (stack order: oldest first in the slice).
+    ///
+    /// Use this for operation lookup so errors can report type names. Returns an empty slice
+    /// if `n` is 0 or greater than the current stack size.
+    #[must_use]
+    pub fn peek_stack_infos(&self, n: usize) -> &[StackInfo] {
+        if n > self.stack_ids.len() {
+            return &[];
+        }
+        let start = self.stack_ids.len() - n;
+        &self.stack_ids[start..]
+    }
+
     /// Pushes a nullary operation that takes no arguments and returns a value of type R.
     ///
     /// The return type is tracked in the type stack for subsequent operations.
@@ -201,21 +266,62 @@
         F: Fn() -> anyhow::Result<R> + 'static,
         R: 'static,
     {
-        let unwind: Vec<_> = self
-            .stack_ids
-            .iter()
-            .map(|info| info.stack_unwind)
-            .collect();
-        self.segment.raw0(move |stack| match op() {
-            Ok(r) => Ok(r),
-            Err(e) => {
-                for dropper in unwind.iter().rev() {
-                    dropper(stack);
-                }
-                Err(e)
-            }
-        });
+        let unwind = self.capture_unwind();
+        self.segment
+            .raw0(move |stack| Self::unwind_on_err(&unwind, stack, op()));
+        self.push_type::<R>();
+    }
+
+    /// Pushes a unary operation that takes one argument of type `T` and returns a `Result`.
+    ///
+    /// If the operation succeeds, the result is pushed onto the stack. If it fails,
+    /// the stack is unwound to its previous state and the error is propagated.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the argument type does not match the expected type.
+    pub fn op1r<F, T, R>(&mut self, op: F) -> Result<()>
+    where
+        F: Fn(T) -> anyhow::Result<R> + 'static,
+        T: 'static,
+        R: 'static,
+    {
+        let [p0] = self.get_last_n_padded::<1>();
+        self.pop_types::<(T, ())>()?;
+        let unwind = self.capture_unwind();
+        self.segment.raw1(
+            move |stack, t| Self::unwind_on_err(&unwind, stack, op(t)),
+            p0,
+        );
+        self.push_type::<R>();
+        Ok(())
+    }
+
+    /// Pushes a binary operation that takes two arguments of types `T` and `U` and returns a `Result`.
+    ///
+    /// If the operation succeeds, the result is pushed onto the stack. If it fails,
+    /// the stack is unwound to its previous state and the error is propagated.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the argument types do not match the expected types.
+    pub fn op2r<F, T, U, R>(&mut self, op: F) -> Result<()>
+    where
+        F: Fn(T, U) -> anyhow::Result<R> + 'static,
+        T: 'static,
+        U: 'static,
+        R: 'static,
+    {
+        let [p0, p1] = self.get_last_n_padded::<2>();
+        self.pop_types::<(T, (U, ()))>()?;
+        let unwind = self.capture_unwind();
+        self.segment.raw2(
+            move |stack, t, u| Self::unwind_on_err(&unwind, stack, op(t, u)),
+            p0,
+            p1,
+        );
         self.push_type::<R>();
+        Ok(())
     }
 
     /// Pushes a value to the stack without any operations.
@@ -333,7 +439,7 @@ impl DynSegment {
             fragment_1.stack_ids.len()
         );
         ensure!(
-            fragment_0.stack_ids[0].stack_id == fragment_1.stack_ids[0].stack_id,
+            fragment_0.stack_ids[0].type_id == fragment_1.stack_ids[0].type_id,
             "fragment result types must match"
         );
@@ -419,10 +525,15 @@ impl DynSegment {
             ));
         }
         if self.argument_ids[0] != TypeId::of::<A>() {
+            let got = self
+                .argument_names
+                .first()
+                .map(Cow::as_ref)
+                .unwrap_or("?");
             return Err(anyhow::anyhow!(
                 "argument type mismatch: expected {}, got {}",
                 std::any::type_name::<A>(),
-                std::any::type_name::<A>() // TODO: Need to store type names along with TypeId
+                got
             ));
         }
         self.pop_types::<(R, ())>()?;
@@ -475,6 +586,64 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn op1r_success() -> Result<(), anyhow::Error> {
+        let mut segment = DynSegment::new::<()>();
+        segment.op0(|| 21u32);
+        segment.op1r(|n: u32| Ok::<_, anyhow::Error>(n * 2))?;
+        let result: u32 = segment.call0()?;
+        assert_eq!(result, 42);
+        Ok(())
+    }
+
+    #[test]
+    fn op1r_error_unwinds() -> Result<(), anyhow::Error> {
+        let mut segment = DynSegment::new::<()>();
+        let drop_count = Arc::new(AtomicUsize::new(0));
+        let tracker = DropCounter(drop_count.clone());
+        segment.op0(move || tracker.clone());
+        segment.op0(|| 7u32);
+        segment.op1r(|_n: u32| -> Result<u32> { Err(anyhow::anyhow!("op1r error")) })?;
+        segment.op1(|_: DropCounter| 0u32)?;
+        segment.op2(|_: DropCounter, x: u32| x)?; // consume to single u32 for call0
+        let result = segment.call0::<u32>();
+        assert!(result.is_err(), "expected Err, got {:?}", result);
+        assert_eq!(result.unwrap_err().to_string(), "op1r error");
+        // DropCounter (under the u32) was unwound when op1r failed.
+        assert_eq!(drop_count.load(Ordering::SeqCst), 1);
+        Ok(())
+    }
+
+    #[test]
+    fn op2r_success() -> Result<(), anyhow::Error> {
+        let mut segment = DynSegment::new::<()>();
+        segment.op0(|| 10u32);
+        segment.op0(|| 32u32);
+        segment.op2r(|a: u32, b: u32| Ok::<_, anyhow::Error>(a + b))?;
+        let result: u32 = segment.call0()?;
+        assert_eq!(result, 42);
+        Ok(())
+    }
+
+    #[test]
+    fn op2r_error_unwinds() -> Result<(), anyhow::Error> {
+        let mut segment = DynSegment::new::<()>();
+        let drop_count = Arc::new(AtomicUsize::new(0));
+        let tracker = DropCounter(drop_count.clone());
+        segment.op0(move || tracker.clone());
+        segment.op0(|| 7u32);
+        segment.op0(|| 8u32);
+        segment.op2r(|_a: u32, _b: u32| -> Result<u32> { Err(anyhow::anyhow!("op2r error")) })?;
+        segment.op1(|_: DropCounter| 0u32)?;
+        segment.op2(|_: DropCounter, x: u32| x)?; // consume to single u32 for call0
+        let result = segment.call0::<u32>();
+        assert!(result.is_err(), "expected Err, got {:?}", result);
+        assert_eq!(result.unwrap_err().to_string(), "op2r error");
+        // DropCounter (under the two u32s) was unwound when op2r failed.
+        assert_eq!(drop_count.load(Ordering::SeqCst), 1);
+        Ok(())
+    }
+
     #[test]
     fn segment_operations() -> Result<(), anyhow::Error> {
         let mut operations = DynSegment::new::<()>();
diff --git a/cel-runtime/src/lib.rs b/cel-runtime/src/lib.rs
index 13083bf..7a288f1 100644
--- a/cel-runtime/src/lib.rs
+++ b/cel-runtime/src/lib.rs
@@ -33,6 +33,7 @@
 //! ```
 
 #![warn(missing_docs)]
+
 /// Compile-time stack list implementation for type-safe stack operations.
 pub mod c_stack_list;
 /// Dynamic segment implementation with runtime type checking.
@@ -53,6 +54,8 @@
 pub mod raw_vec;
 pub mod segment;
 /// Tuple list implementation for type-safe tuple operations.
 pub mod tuple_list;
+/// Recursive descent parser for CEL expressions.
+pub mod parser;
 
 pub use c_stack_list::*;
 pub use dyn_segment::*;
@@ -64,3 +67,16 @@
 pub use raw_stack::*;
 pub use raw_vec::*;
 pub use segment::*;
 //pub use tuple_list::*;
+
+pub use parser::op_table::OpLookup;
+pub use parser::CELParser;
+pub use parser::CELError;
+
+impl std::str::FromStr for DynSegment {
+    type Err = parser::CELError;
+
+    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+        let mut parser = parser::CELParser::new(OpLookup::new());
+        parser.parse_str(s)
+    }
+}
diff --git a/cel-runtime/src/parser/error.rs b/cel-runtime/src/parser/error.rs
new file mode 100644
index 0000000..ab3357a
--- /dev/null
+++ b/cel-runtime/src/parser/error.rs
@@ -0,0 +1,190 @@
+//! Parse error type with message and source span for CEL.
+//!
+//! Uses a [`SourceSpan`] (line/column only) so errors are `Send + Sync` and can
+//! be used from async execution. Use [`SourceSpan::from_proc_macro2`] to extract
+//! location from a `proc_macro2::Span` when building errors in the parser.
+
+use owo_colors::OwoColorize;
+use proc_macro2::LineColumn;
+
+/// Source region as start/end line and column.
+///
+/// Uses [`proc_macro2::LineColumn`] for positions (1-based line, 0-based column).
+/// This type is `Send + Sync`. Build it from a `proc_macro2::Span` via
+/// [`SourceSpan::from_proc_macro2`] when you have one (e.g. in the parser).
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub struct SourceSpan {
+    /// Start position (inclusive).
+    pub start: LineColumn,
+    /// End position (inclusive).
+    pub end: LineColumn,
+}
+
+impl Default for SourceSpan {
+    fn default() -> Self {
+        SourceSpan {
+            start: LineColumn { line: 0, column: 0 },
+            end: LineColumn { line: 0, column: 0 },
+        }
+    }
+}
+
+impl SourceSpan {
+    /// Builds a span from raw line/column values.
+    ///
+    /// Lines are 1-based, columns are 0-based (matching [`proc_macro2::LineColumn`]).
+    pub fn new(
+        start_line: usize,
+        start_column: usize,
+        end_line: usize,
+        end_column: usize,
+    ) -> Self {
+        SourceSpan {
+            start: LineColumn {
+                line: start_line,
+                column: start_column,
+            },
+            end: LineColumn {
+                line: end_line,
+                column: end_column,
+            },
+        }
+    }
+
+    /// Extracts start/end line and column from a `proc_macro2::Span`.
+    ///
+    /// Use this when creating errors in the parser or other code that has a
+    /// `proc_macro2::Span`; the result is `Send + Sync` and can be stored in
+    /// [`CELError`] for use from async or other threads.
+    pub fn from_proc_macro2(span: proc_macro2::Span) -> Self {
+        SourceSpan {
+            start: span.start(),
+            end: span.end(),
+        }
+    }
+}
+
+/// A CEL parse error with a message and source location.
+///
+/// Uses a [`SourceSpan`] (line/column only) so the error is `Send + Sync` and
+/// can be used from async execution or reported across thread boundaries.
+#[derive(Clone, Debug)]
+pub struct CELError {
+    message: String,
+    span: SourceSpan,
+}
+
+impl CELError {
+    /// Creates a new error with the given message and source span.
+    pub fn new(message: impl Into<String>, span: SourceSpan) -> Self {
+        CELError {
+            message: message.into(),
+            span,
+        }
+    }
+
+    /// Creates a new error from a message and a `proc_macro2::Span`.
+    ///
+    /// Extracts line/column from the span so the resulting error is `Send + Sync`.
+    pub fn with_proc_macro_span(message: impl Into<String>, span: proc_macro2::Span) -> Self {
+        CELError::new(message, SourceSpan::from_proc_macro2(span))
+    }
+
+    /// Returns the error message.
+    pub fn message(&self) -> &str {
+        &self.message
+    }
+
+    /// Returns the source span for this error.
+    pub fn span(&self) -> SourceSpan {
+        self.span
+    }
+
+    /// Formats this error in rustc diagnostic style with source context.
+    ///
+    /// Produces a multi-line string similar to Rust compiler diagnostics,
+    /// including the source file location, error message, and a caret
+    /// indicating the error position.
+ /// + /// # Arguments + /// + /// * `source_code` - The original source code being parsed + /// * `filename` - The name of the file (for display) + /// * `start_line` - The starting line number in the original file (1-based) + /// + /// # References + /// + /// See the [rustc diagnostic formatting guide](https://github.com/rust-lang/rustc-dev-guide/blob/master/src/diagnostics.md). + pub fn format_rustc_style( + &self, + source_code: &str, + filename: &str, + start_line: u32, + ) -> String { + let start = self.span.start; + let end = self.span.end; + let lines: Vec<&str> = source_code.lines().collect(); + + let mut output = String::new(); + let error_line = start_line + (start.line as u32) - 1; + let error_column = start.column + 1; + let max_line_num = start_line + (end.line as u32) - 1; + let line_width = max_line_num.to_string().len(); + + output.push_str(&format!("{}: {}\n", "error".red().bold(), self.message)); + output.push_str(&format!( + " {} {}:{}:{}\n", + "-->".blue().bold(), + filename.blue(), + error_line.to_string().blue(), + error_column.to_string().blue() + )); + output.push_str(&format!( + "{:width$} {}\n", + "", + "|".blue().bold(), + width = line_width + )); + + for line_num in start.line..=end.line { + if let Some(line_content) = lines.get(line_num.saturating_sub(1)) { + let display_line_num = start_line + (line_num as u32) - 1; + output.push_str(&format!( + "{} {} {}\n", + display_line_num.to_string().blue().bold(), + "|".blue().bold(), + line_content + )); + + if line_num == start.line { + output.push_str(&format!( + "{:width$} {} ", + "", + "|".blue().bold(), + width = line_width + )); + output.push_str(&" ".repeat(start.column)); + let caret_len = if start.line == end.line { + end.column.saturating_sub(start.column).max(1) + } else { + line_content + .len() + .saturating_sub(start.column.saturating_sub(1)) + }; + output.push_str(&"^".repeat(caret_len).red().bold().to_string()); + output.push('\n'); + } + } + } + + output + } +} + +impl 
std::fmt::Display for CELError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.message) + } +} + +impl std::error::Error for CELError {} diff --git a/cel-runtime/src/parser/lex_lexer.rs b/cel-runtime/src/parser/lex_lexer.rs new file mode 100644 index 0000000..0158940 --- /dev/null +++ b/cel-runtime/src/parser/lex_lexer.rs @@ -0,0 +1,533 @@ +//! A `lex-lexer` is a lexer taking a lex token stream and returning a token stream of a different +//! type. Initially this is used to convert the TokenTree from Rust's proc_macro into a higher level +//! token stream. The goal, however, is to be able to specify with a grammar how to process a token +//! stream. +//! +//! # Error Handling +//! +//! This lexer does not produce errors. All input `TokenTree` items come pre-validated from +//! `proc_macro2`, which has already verified correct Rust lexical syntax. The lexer only +//! transforms and flattens tokens, operations that cannot fail on valid input. Any impossible +//! states (like receiving a `Punct` or `Group` in `convert_token`) use `unreachable!()` since +//! they represent programming errors, not malformed input. + +use proc_macro2::{Delimiter, Ident, Spacing, Span, TokenTree}; +use syn::Lit; + +/// A trait for token types that provides access to span information for error reporting. +/// +/// This trait is minimal by design - token type discrimination is done through pattern +/// matching on enum variants, not through trait methods. This keeps the trait simple +/// and allows different token types to have their own specific fields and methods. +pub trait HasSpan { + /// Get the span for error reporting. + fn span(&self) -> Span; +} + +/// A group iterator with its associated close delimiter information. +struct GroupLevel { + iter: TokenStreamIter, + delimiter: Delimiter, + span: Span, +} + +/// A lexer that transforms a `TokenTree` stream into a flattened `Token` stream. 
+///
+/// Groups are flattened into OpenDelim and CloseDelim tokens, and literals are
+/// eagerly discriminated into specific types. Flattening is lazy - group iterators
+/// are pushed onto a stack and processed one token at a time.
+///
+/// Multi-character operators are combined at this level (e.g., `&` + `&` -> `&&`).
+pub struct LexLexer {
+    input: TokenStreamIter,
+    /// Stack of iterators for nested groups - allows lazy flattening.
+    /// Each entry tracks the iterator and its close delimiter info.
+    group_stack: Vec<GroupLevel>,
+    /// Pending close delimiter to emit when we've just exhausted a group iterator.
+    pending_close: Option<(Delimiter, Span)>,
+    /// Pending token that was consumed while looking ahead.
+    pending_token: Option<TokenTree>,
+}
+
+/// Iterator type used for token streams (from `TokenStream::into_iter()`).
+pub type TokenStreamIter = proc_macro2::token_stream::IntoIter;
+
+/// Punctuation operator (1 or 2 chars) without heap allocation.
+#[derive(Clone, Debug)]
+pub enum PunctOp {
+    /// Single character (e.g. `+`, `-`).
+    One(char),
+    /// Two characters (e.g. `&&`, `<=`).
+    Two([char; 2]),
+}
+
+impl PartialEq<str> for PunctOp {
+    fn eq(&self, other: &str) -> bool {
+        match self {
+            PunctOp::One(c) => other.len() == 1 && other.chars().next() == Some(*c),
+            PunctOp::Two([a, b]) => {
+                let mut it = other.chars();
+                it.next() == Some(*a) && it.next() == Some(*b) && it.next().is_none()
+            }
+        }
+    }
+}
+
+impl PartialEq<&str> for PunctOp {
+    fn eq(&self, other: &&str) -> bool {
+        self.eq(*other)
+    }
+}
+
+impl LexLexer {
+    /// Creates a new lexer from a token tree iterator.
+    ///
+    /// # Arguments
+    ///
+    /// * `input` - An iterator over `TokenTree` items to be lexed into `Token`s
+    pub fn new(input: TokenStreamIter) -> Self {
+        Self {
+            input,
+            group_stack: Vec::new(),
+            pending_close: None,
+            pending_token: None,
+        }
+    }
+
+    /// Converts a single TokenTree into a Token (except Punct and Group which are handled specially).
+ /// + /// This handles Literal and Identifier tokens. Boolean identifiers (`true`, `false`) are + /// converted to Boolean literals. Punct tokens need special handling for combining + /// multi-char operators, and Groups are handled by the iterator. + fn convert_token(token: TokenTree) -> Token { + match token { + TokenTree::Literal(lit) => { + // Wrap in TokenTree and convert to TokenStream to preserve span information + let token_stream: proc_macro2::TokenStream = TokenTree::Literal(lit).into(); + let syn_lit: Lit = syn::parse2(token_stream) + .expect("proc_macro2 Literal should parse as syn::Lit"); + + // Verbatim literals should never occur when parsing proc_macro2::Literal + debug_assert!( + !matches!(syn_lit, Lit::Verbatim(_)), + "Unexpected Verbatim literal from proc_macro2::Literal" + ); + + Token::Literal(syn_lit) + } + TokenTree::Ident(ident) => { + let ident_str = ident.to_string(); + + // Check if this is a boolean literal + match ident_str.as_str() { + "true" | "false" => { + // Wrap in TokenTree and convert to TokenStream to preserve span information + let token_stream: proc_macro2::TokenStream = TokenTree::Ident(ident).into(); + let syn_lit: Lit = syn::parse2(token_stream) + .expect("boolean identifier should parse as syn::Lit::Bool"); + Token::Literal(syn_lit) + } + _ => Token::Identifier(ident), + } + } + TokenTree::Punct(_) | TokenTree::Group(_) => { + // These should be handled by the iterator + unreachable!("Punct and Group tokens should be handled by the iterator, not convert_token") + } + } + } + + /// Check if two characters form a known multi-character operator. + fn is_compound_operator(first: char, second: char) -> bool { + matches!( + (first, second), + ('&', '&') | ('|', '|') | ('=', '=') | ('!', '=') | + ('<', '=') | ('>', '=') | ('<', '<') | ('>', '>') + ) + } + + /// Get the next TokenTree from the current iterator (top of stack or main input). + /// Returns None and sets pending_close when an iterator is exhausted. 
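The doc comments above describe the lazy, stack-based flattening, but the pattern is easy to miss inside the full lexer. Here is a standalone sketch of the same idea over a toy nested tree; the `Tree` type and `flatten` function are illustrative only and are not part of this crate:

```rust
/// Lazy flattening with an explicit stack of iterators, the same shape as
/// `group_stack`: each stack entry is the not-yet-consumed remainder of one
/// nesting level, and an exhausted level resumes its parent.
enum Tree {
    Leaf(i32),
    Node(Vec<Tree>),
}

fn flatten(tree: Tree) -> Vec<i32> {
    let mut out = Vec::new();
    // Seed the stack with a one-element iterator over the root.
    let mut stack = vec![vec![tree].into_iter()];
    while let Some(top) = stack.last_mut() {
        match top.next() {
            Some(Tree::Leaf(n)) => out.push(n),
            // Entering a group pushes its child iterator (lazy descent).
            Some(Tree::Node(children)) => stack.push(children.into_iter()),
            // This level is exhausted; pop it and resume the parent level.
            None => {
                stack.pop();
            }
        }
    }
    out
}
```

The real lexer does the same thing, except that entering and leaving a level also emits `OpenDelim`/`CloseDelim` tokens.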
+    fn next_token_tree(&mut self) -> Option<TokenTree> {
+        // Check if we have a pending token from lookahead
+        if let Some(token) = self.pending_token.take() {
+            return Some(token);
+        }
+
+        // Try to get from the top of the group stack first
+        if let Some(level) = self.group_stack.last_mut() {
+            if let Some(tt) = level.iter.next() {
+                return Some(tt);
+            }
+            // Current iterator exhausted, pop it and set pending close
+            let level = self.group_stack.pop().unwrap();
+            self.pending_close = Some((level.delimiter, level.span));
+            return None; // Signal that we need to emit close delimiter
+        }
+
+        // Stack is empty, get from main input
+        self.input.next()
+    }
+}
+
+/// A parsed literal value using syn's Lit enum.
+///
+/// This is a simple wrapper around syn's `Lit` type that includes boolean literals
+/// even though they appear as identifiers in proc_macro2 (converted during lexing).
+pub type Literal = Lit;
+
+impl HasSpan for Literal {
+    fn span(&self) -> Span {
+        match self {
+            Lit::Str(lit) => lit.span(),
+            Lit::ByteStr(lit) => lit.span(),
+            Lit::CStr(lit) => lit.span(),
+            Lit::Byte(lit) => lit.span(),
+            Lit::Char(lit) => lit.span(),
+            Lit::Int(lit) => lit.span(),
+            Lit::Float(lit) => lit.span(),
+            Lit::Bool(lit) => lit.span(),
+            Lit::Verbatim(_) => unreachable!("Verbatim literals should never occur"),
+            _ => Span::call_site(),
+        }
+    }
+}
+
+/// A flattened token that represents elements from a TokenTree stream.
+///
+/// Groups are flattened into OpenDelim and CloseDelim tokens, making parsing
+/// simpler by removing nesting from the token stream.
+#[derive(Debug)]
+pub enum Token {
+    /// A literal value (integer, string, boolean, or float) with eager discrimination.
+    Literal(Literal),
+
+    /// An identifier.
+    Identifier(Ident),
+
+    /// A punctuation operator (single or multi-character; no heap for 1–2 chars).
+    Punct {
+        /// The operator (e.g., "+", "&&", "<=").
+        op: PunctOp,
+        /// Span for error reporting.
+        span: Span,
+    },
+
+    /// Opening delimiter (flattened from Group).
+    OpenDelim {
+        /// The type of delimiter (Parenthesis, Brace, Bracket).
+        delimiter: Delimiter,
+        /// Span for error reporting.
+        span: Span,
+    },
+
+    /// Closing delimiter (flattened from Group).
+    CloseDelim {
+        /// The type of delimiter (Parenthesis, Brace, Bracket).
+        delimiter: Delimiter,
+        /// Span for error reporting.
+        span: Span,
+    },
+}
+
+impl HasSpan for Token {
+    fn span(&self) -> Span {
+        match self {
+            Token::Literal(lit) => lit.span(),
+            Token::Identifier(ident) => ident.span(),
+            Token::Punct { span, .. } => *span,
+            Token::OpenDelim { span, .. } => *span,
+            Token::CloseDelim { span, .. } => *span,
+        }
+    }
+}
+
+impl HasSpan for TokenTree {
+    fn span(&self) -> Span {
+        match self {
+            TokenTree::Group(g) => g.span(),
+            TokenTree::Ident(i) => i.span(),
+            TokenTree::Punct(p) => p.span(),
+            TokenTree::Literal(l) => l.span(),
+        }
+    }
+}
+
+impl Iterator for LexLexer {
+    type Item = Token;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // Check if we have a pending close delimiter to emit
+        if let Some((delimiter, span)) = self.pending_close.take() {
+            return Some(Token::CloseDelim { delimiter, span });
+        }
+
+        // Get next token tree from current iterator
+        let token = match self.next_token_tree() {
+            Some(tt) => tt,
+            None if self.pending_close.is_some() => {
+                // An iterator was exhausted and pending_close was set
+                // Emit the close delimiter on the next call (recursive call)
+                return self.next();
+            }
+            None => {
+                // All iterators exhausted, no more tokens
+                return None;
+            }
+        };
+
+        // Handle Groups by pushing their iterator onto the stack
+        if let TokenTree::Group(group) = token {
+            let delimiter = group.delimiter();
+            let span = group.span();
+
+            // Push the group's iterator and close info onto the stack
+            self.group_stack.push(GroupLevel {
+                iter: group.stream().into_iter(),
+                delimiter,
+                span,
+            });
+
+            // Return OpenDelim immediately
+            return Some(Token::OpenDelim { delimiter, span
}); + } + + // Handle Punct tokens with potential combining + if let TokenTree::Punct(punct) = token { + let ch = punct.as_char(); + let spacing = punct.spacing(); + let span = punct.span(); + + // If spacing is Joint, try to combine with next punct + if spacing == Spacing::Joint { + // Get next token to see if we can combine + match self.next_token_tree() { + Some(TokenTree::Punct(next_punct)) => { + let next_ch = next_punct.as_char(); + + // Check if they form a compound operator + if Self::is_compound_operator(ch, next_ch) { + return Some(Token::Punct { + op: PunctOp::Two([ch, next_ch]), + span, + }); + } else { + self.pending_token = Some(TokenTree::Punct(next_punct)); + return Some(Token::Punct { + op: PunctOp::One(ch), + span, + }); + } + } + Some(other_token) => { + self.pending_token = Some(other_token); + return Some(Token::Punct { + op: PunctOp::One(ch), + span, + }); + } + None => { + return Some(Token::Punct { + op: PunctOp::One(ch), + span, + }); + } + } + } else { + return Some(Token::Punct { + op: PunctOp::One(ch), + span, + }); + } + } + + // Not a group or punct, convert directly (Literal or Ident) + Some(Self::convert_token(token)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::TokenStream; + use std::str::FromStr; + + #[test] + fn test_literal_integer() { + let input = TokenStream::from_str("42").unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + match token { + Token::Literal(Lit::Int(..)) => {} + _ => panic!("Expected integer literal, got {:?}", token), + } + } + + #[test] + fn test_literal_string() { + let input = TokenStream::from_str(r#""hello""#).unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + match token { + Token::Literal(Lit::Str(..)) => {} + _ => panic!("Expected string literal"), + } + } + + #[test] + fn test_literal_boolean() { + // Test 'true' boolean literal + let input = 
TokenStream::from_str("true").unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + match token { + Token::Literal(Lit::Bool(lit_bool)) => { + assert_eq!(lit_bool.value, true); + } + _ => panic!("Expected boolean literal for 'true', got {:?}", token), + } + + // Test 'false' boolean literal + let input = TokenStream::from_str("false").unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + match token { + Token::Literal(Lit::Bool(lit_bool)) => { + assert_eq!(lit_bool.value, false); + } + _ => panic!("Expected boolean literal for 'false', got {:?}", token), + } + } + + #[test] + fn test_literal_float() { + let input = TokenStream::from_str("3.14").unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + match token { + Token::Literal(Lit::Float(..)) => {} + _ => panic!("Expected float literal"), + } + } + + #[test] + fn test_identifier() { + let input = TokenStream::from_str("foo").unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + match token { + Token::Identifier(ident) => { + assert_eq!(ident.to_string(), "foo"); + } + _ => panic!("Expected identifier"), + } + } + + #[test] + fn test_punct() { + let input = TokenStream::from_str("+").unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + match token { + Token::Punct { op, .. } => { + assert!(op == "+"); + } + _ => panic!("Expected punctuation"), + } + } + + #[test] + fn test_compound_operator() { + let input = TokenStream::from_str("a && b").unwrap(); + let lexer = LexLexer::new(input.into_iter()); + + let tokens: Vec<_> = lexer.collect(); + assert_eq!(tokens.len(), 3); + + match &tokens[1] { + Token::Punct { op, .. 
} => { + assert!(op == "&&"); + } + _ => panic!("Expected && operator"), + } + } + + #[test] + fn test_group_flattening() { + let input = TokenStream::from_str("(10 + 20)").unwrap(); + let lexer = LexLexer::new(input.into_iter()); + + // Should get: OpenDelim, Integer, Punct, Integer, CloseDelim + let tokens: Vec<_> = lexer.collect(); + assert_eq!(tokens.len(), 5); + + assert!(matches!(tokens[0], Token::OpenDelim { delimiter: Delimiter::Parenthesis, .. })); + assert!(matches!(tokens[1], Token::Literal(Lit::Int(..)))); + assert!(matches!(&tokens[2], Token::Punct { op, .. } if op == "+")); + assert!(matches!(tokens[3], Token::Literal(Lit::Int(..)))); + assert!(matches!(tokens[4], Token::CloseDelim { delimiter: Delimiter::Parenthesis, .. })); + } + + #[test] + fn test_nested_groups() { + let input = TokenStream::from_str("(10 + (20 * 30))").unwrap(); + let lexer = LexLexer::new(input.into_iter()); + + let tokens: Vec<_> = lexer.collect(); + + // Should have: OpenDelim, 10, +, OpenDelim, 20, *, 30, CloseDelim, CloseDelim + assert_eq!(tokens.len(), 9); + + // Verify structure + assert!(matches!(tokens[0], Token::OpenDelim { .. })); + assert!(matches!(tokens[1], Token::Literal(Lit::Int(..)))); + assert!(matches!(&tokens[2], Token::Punct { op, .. } if op == "+")); + assert!(matches!(tokens[3], Token::OpenDelim { .. })); + assert!(matches!(tokens[4], Token::Literal(Lit::Int(..)))); + assert!(matches!(&tokens[5], Token::Punct { op, .. } if op == "*")); + assert!(matches!(tokens[6], Token::Literal(Lit::Int(..)))); + assert!(matches!(tokens[7], Token::CloseDelim { .. })); + assert!(matches!(tokens[8], Token::CloseDelim { .. 
})); + } + + #[test] + fn test_span_preservation() { + let input = TokenStream::from_str("foo").unwrap(); + let mut lexer = LexLexer::new(input.into_iter()); + + let token = lexer.next().unwrap(); + + // HasSpan trait should provide span + let span = HasSpan::span(&token); + assert!(!span.source_text().unwrap_or_default().is_empty()); + } + + #[test] + fn test_haspan_trait_for_tokentree() { + let input = TokenStream::from_str("42").unwrap(); + let tt = input.into_iter().next().unwrap(); + + // TokenTree implements HasSpan trait + let _span = HasSpan::span(&tt); + } + + #[test] + fn test_mixed_tokens() { + let input = TokenStream::from_str("foo + 42").unwrap(); + let lexer = LexLexer::new(input.into_iter()); + + let tokens: Vec<_> = lexer.collect(); + assert_eq!(tokens.len(), 3); + + assert!(matches!(tokens[0], Token::Identifier(_))); + assert!(matches!(&tokens[1], Token::Punct { op, .. } if op == "+")); + assert!(matches!(tokens[2], Token::Literal(Lit::Int(..)))); + } +} diff --git a/cel-runtime/src/parser/mod.rs b/cel-runtime/src/parser/mod.rs new file mode 100644 index 0000000..5550294 --- /dev/null +++ b/cel-runtime/src/parser/mod.rs @@ -0,0 +1,1300 @@ +#![warn(missing_docs)] + +//! A recursive descent parser for CEL (Common Expression Language) expressions. +//! +//! This crate provides a parser that can parse CEL expressions into executable segments. +//! The parser follows the CEL grammar specification and provides detailed error reporting +//! with source location information. +//! +//! # Error Handling +//! +//! Parse errors are returned as [`CELError`], which carries a message and source span for diagnostics. +//! All errors result from malformed input (syntax errors, type mismatches, undefined identifiers). +//! +//! # Grammar +//! +//! ```text +//! expression = or_expression ?eos?. +//! or_expression = and_expression { "||" and_expression }. +//! and_expression = comparison_expression { "&&" comparison_expression }. +//! 
comparison_expression = bitwise_or_expression
+//!     [ ("==" | "!=" | "<" | ">" | "<=" | ">=") bitwise_or_expression ].
+//! bitwise_or_expression = bitwise_xor_expression { "|" bitwise_xor_expression }.
+//! bitwise_xor_expression = bitwise_and_expression { "^" bitwise_and_expression }.
+//! bitwise_and_expression = bitwise_shift_expression { "&" bitwise_shift_expression }.
+//! bitwise_shift_expression = additive_expression { ("<<" | ">>") additive_expression }.
+//! additive_expression = multiplicative_expression { ("+" | "-") multiplicative_expression }.
+//! multiplicative_expression = unary_expression { ("*" | "/" | "%") unary_expression }.
+//! unary_expression = (("-" | "!") unary_expression) | postfix_expression.
+//! postfix_expression = primary_expression { "(" parameter_list ")" }.
+//! primary_expression = literal | identifier | "(" or_expression ")".
+//! parameter_list = [ or_expression { "," or_expression } ].
+//! ```
+//!
+//! # Note
+//!
+//! `?eos?` denotes end of stream.
+//!
+//! # Examples
+//!
+//! ```rust
+//! use cel_runtime::DynSegment;
+//! use std::str::FromStr;
+//!
+//! let mut segment: DynSegment = "10u32 + 20u32 * 5u32".parse().unwrap();
+//! let result = segment.call0::<u32>();
+//! assert!(result.is_ok());
+//! assert_eq!(result.unwrap(), 110); // 10 + 20 * 5 = 10 + 100
+//! ```
+//!
+//! ## Basic Usage
+//!
+//! ```rust
+//! use cel_runtime::parser::CELParser;
+//! use cel_runtime::OpLookup;
+//! use proc_macro2::TokenStream;
+//! use std::str::FromStr;
+//!
+//! let input = TokenStream::from_str("10").unwrap();
+//! let mut parser = CELParser::new(OpLookup::new());
+//! parser.set_tokens(input.into_iter());
+//! let result = parser.is_expression();
+//! assert!(result.is_ok());
+//! ```
+//!
+//! ## Error Formatting
+//!
+//! ```rust
+//! use cel_runtime::parser::CELParser;
+//! use cel_runtime::OpLookup;
+//! use proc_macro2::TokenStream;
+//! use std::str::FromStr;
+//!
+//! let line = line!() + 1;
+//! let source = r#"
+//!     10 20
+//! "#; // Invalid: missing operator
+//! let input = TokenStream::from_str(source).unwrap();
+//! let mut parser = CELParser::new(OpLookup::new());
+//! parser.set_tokens(input.into_iter());
+//!
+//! if let Err(e) = parser.is_expression() {
+//!     // Format error starting at line 1
+//!     println!("{}", e.format_rustc_style(source, file!(), line));
+//!     // Output:
+//!     // error: unexpected token
+//!     //   --> example.cel:1:4
+//!     //    |
+//!     //  1 | 10 20
+//!     //    |    ^^
+//! }
+//! ```
+
+mod error;
+mod lex_lexer;
+pub mod op_table;
+
+pub use error::{CELError, SourceSpan};
+pub use proc_macro2::LineColumn;
+
+use lex_lexer::{LexLexer, Literal as CelLiteral, Token, TokenStreamIter};
+use op_table::OpLookup;
+
+use crate::DynSegment;
+use proc_macro2::{Delimiter, Ident, Literal, Span, TokenStream};
+use std::iter::Peekable;
+use std::str::FromStr;
+
+/// Parser result type.
+pub type Result<T> = std::result::Result<T, CELError>;
+
+fn push_literal(output: &mut DynSegment, lit: CelLiteral) {
+    match lit {
+        CelLiteral::Int(integer) => {
+            // Use syn's suffix() to determine the type
+            match integer.suffix() {
+                "u8" => output.just(
+                    integer
+                        .base10_parse::<u8>()
+                        .expect("failed to parse u8 literal"),
+                ),
+                "u16" => output.just(
+                    integer
+                        .base10_parse::<u16>()
+                        .expect("failed to parse u16 literal"),
+                ),
+                "u32" => output.just(
+                    integer
+                        .base10_parse::<u32>()
+                        .expect("failed to parse u32 literal"),
+                ),
+                "u64" => output.just(
+                    integer
+                        .base10_parse::<u64>()
+                        .expect("failed to parse u64 literal"),
+                ),
+                "u128" => output.just(
+                    integer
+                        .base10_parse::<u128>()
+                        .expect("failed to parse u128 literal"),
+                ),
+                "usize" => output.just(
+                    integer
+                        .base10_parse::<usize>()
+                        .expect("failed to parse usize literal"),
+                ),
+                "i8" => output.just(
+                    integer
+                        .base10_parse::<i8>()
+                        .expect("failed to parse i8 literal"),
+                ),
+                "i16" => output.just(
+                    integer
+                        .base10_parse::<i16>()
+                        .expect("failed to parse i16 literal"),
+                ),
+                "i64" => output.just(
+                    integer
+                        .base10_parse::<i64>()
+                        .expect("failed to parse i64 literal"),
+                ),
+                "i128" => output.just(
+                    integer
+                        .base10_parse::<i128>()
+                        .expect("failed to parse i128 literal"),
+                ),
+                "isize" => output.just(
+                    integer
+                        .base10_parse::<isize>()
+                        .expect("failed to parse isize literal"),
+                ),
+                _ => {
+                    // No suffix means i32 by default
+                    output.just(
+                        integer
+                            .base10_parse::<i32>()
+                            .expect("failed to parse i32 literal"),
+                    )
+                }
+            }
+        }
+        CelLiteral::Float(float) => {
+            // Use syn's suffix() to determine the type
+            match float.suffix() {
+                "f32" => output.just(
+                    float
+                        .base10_parse::<f32>()
+                        .expect("failed to parse f32 literal"),
+                ),
+                _ => {
+                    // No suffix or "f64" means f64 by default
+                    output.just(
+                        float
+                            .base10_parse::<f64>()
+                            .expect("failed to parse f64 literal"),
+                    )
+                }
+            }
+        }
+        CelLiteral::Str(string) => {
+            // Store the string value (without quotes)
+            output.just(string.value());
+        }
+        CelLiteral::Bool(lit_bool) => {
+            // Push the boolean value directly
+            output.just(lit_bool.value);
+        }
+        CelLiteral::Char(ch) => {
+            // Push character literal
+            output.just(ch.value());
+        }
+        CelLiteral::Byte(byte) => {
+            // Push byte literal (u8)
+            output.just(byte.value());
+        }
+        CelLiteral::ByteStr(byte_str) => {
+            // Push byte string as Vec<u8>
+            output.just(byte_str.value());
+        }
+        CelLiteral::CStr(c_str) => {
+            // Push C string directly
+            output.just(c_str.value());
+        }
+        CelLiteral::Verbatim(_) => {
+            unreachable!("Verbatim literals should never occur")
+        }
+        _ => {
+            // Future literal types not yet handled
+        }
+    }
+}
+
+/// A recursive descent parser for expressions.
+///
+/// # Examples
+///
+/// ## Basic Usage
+///
+/// ```rust
+/// use cel_runtime::OpLookup;
+/// use cel_runtime::parser::CELParser;
+/// use proc_macro2::TokenStream;
+/// use std::str::FromStr;
+///
+/// let input = TokenStream::from_str("10").unwrap();
+/// let mut parser = CELParser::new(OpLookup::new());
+/// parser.set_tokens(input.into_iter());
+/// let result = parser.is_expression();
+/// assert!(result.is_ok());
+/// ```
+///
+/// ## Error Formatting
+///
+/// ```rust
+/// use cel_runtime::OpLookup;
+/// use cel_runtime::parser::CELParser;
+/// use proc_macro2::TokenStream;
+/// use std::str::FromStr;
+///
+/// let line = line!() + 1;
+/// let source = r#"
+///     10 + 20 30
+/// "#; // Invalid: missing operator
+/// let input = TokenStream::from_str(source).unwrap();
+/// let mut parser = CELParser::new(OpLookup::new());
+/// parser.set_tokens(input.into_iter());
+///
+/// if let Err(e) = parser.is_expression() {
+///     // Format error starting at line 1
+///     println!("{}", e.format_rustc_style(source, file!(), line));
+///     // Output:
+///     // error: unexpected token
+///     //   --> example.cel:1:8
+///     //    |
+///     //  1 | 10 + 20 30
+///     //    |         ^^
+/// }
+/// ```
+pub struct CELParser {
+    tokens: Option<Peekable<LexLexer>>,
+    context: DynSegment,
+    op_lookup: OpLookup,
+}
+
+/// A primary expression representing the most basic expression types.
+///
+/// Primary expressions are the atomic building blocks of CEL expressions,
+/// consisting of either literal values or identifiers.
+pub enum PrimaryExpression {
+    /// A literal value (integer, string, boolean, or float).
+    Literal(Literal),
+    /// An identifier referencing a variable or function.
+    Ident(Ident),
+}
+
+/// Result type for parser probe operations.
+///
+/// A `Probe` represents the outcome of attempting to parse a specific grammar
+/// production without committing to the parse. This enables backtracking and
+/// alternative parsing strategies.
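The three-way `Probe` result distinguishes "production absent" from "production present but valueless" from "production present with a value", which a plain `Option` would collapse into two cases. A standalone sketch of the pattern (the enum mirrors the one in this file; the `probe_digit` helper is hypothetical and not part of the parser):

```rust
/// Outcome of a parser probe (mirrors this module's `Probe<T>`).
pub enum Probe<T> {
    /// The probe did not match the expected grammar production.
    NoMatch,
    /// The probe matched but produced no value (e.g., optional production absent).
    Match,
    /// The probe matched and produced a value.
    Value(T),
}

/// Hypothetical probe: a leading ASCII digit yields a value, a leading
/// non-digit is a non-match, and empty input counts as an absent optional.
pub fn probe_digit(s: &str) -> Probe<u32> {
    match s.chars().next() {
        Some(c) if c.is_ascii_digit() => Probe::Value(c.to_digit(10).unwrap()),
        Some(_) => Probe::NoMatch,
        None => Probe::Match,
    }
}
```

A caller can then backtrack on `NoMatch` and try an alternative production, while `Match` still counts as success.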
+pub enum Probe<T> {
+    /// The probe did not match the expected grammar production.
+    NoMatch,
+    /// The probe matched but produced no value (e.g., optional production absent).
+    Match,
+    /// The probe matched and produced a value.
+    Value(T),
+}
+
+/// A probe result for primary expression parsing.
+pub type PrimaryProbe = Probe<PrimaryExpression>;
+
+impl CELParser {
+    /// Creates a new CEL parser with the given operation lookup.
+    ///
+    /// No tokens are set at construction; use [`set_tokens`](Self::set_tokens),
+    /// [`parse_tokens`](Self::parse_tokens), or [`parse_str`](Self::parse_str) to parse.
+    ///
+    /// # Arguments
+    ///
+    /// * `op_lookup` - Operation lookup for resolving operators and identifiers
+    pub fn new(op_lookup: OpLookup) -> Self {
+        CELParser {
+            tokens: None,
+            context: DynSegment::new::<()>(),
+            op_lookup,
+        }
+    }
+
+    /// Sets the token stream for parsing.
+    ///
+    /// Call before [`is_expression`](Self::is_expression) or use [`parse_tokens`](Self::parse_tokens)
+    /// which sets tokens and parses in one step.
+    pub fn set_tokens(&mut self, tokens: TokenStreamIter) {
+        self.tokens = Some(LexLexer::new(tokens).peekable());
+        self.context = DynSegment::new::<()>();
+    }
+
+    /// Parses a token stream into a [`DynSegment`].
+    ///
+    /// Sets the token source, runs the expression grammar, and returns the segment on success.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the input does not contain a valid CEL expression.
+    pub fn parse_tokens(&mut self, tokens: TokenStreamIter) -> Result<DynSegment> {
+        self.set_tokens(tokens);
+        if !self.is_expression()? {
+            return Err(self.error_at("expression expected"));
+        }
+        Ok(std::mem::replace(
+            &mut self.context,
+            DynSegment::new::<()>(),
+        ))
+    }
+
+    /// Parses a string into a [`DynSegment`].
+    ///
+    /// Tokenizes the string then parses; equivalent to `parse_tokens(TokenStream::from_str(s)?.into_iter())`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error on lex failure or if the input does not contain a valid CEL expression.
+    pub fn parse_str(&mut self, s: &str) -> Result<DynSegment> {
+        let input = TokenStream::from_str(s)
+            .map_err(|e| CELError::with_proc_macro_span(format!("lex: {}", e), e.span()))?;
+        self.parse_tokens(input.into_iter())
+    }
+
+    /// Returns a mutable reference to the operation lookup.
+    ///
+    /// This allows customization of the operations available during parsing,
+    /// such as adding new scopes for custom operations or identifiers.
+    ///
+    /// # Examples
+    ///
+    /// ```rust
+    /// use cel_runtime::parser::op_table::OpLookup;
+    /// use cel_runtime::parser::CELParser;
+    /// use cel_runtime::DynSegment;
+    /// use proc_macro2::TokenStream;
+    /// use std::any::TypeId;
+    /// use std::str::FromStr;
+    ///
+    /// let input = TokenStream::from_str("10 + 20").unwrap();
+    /// let mut lookup = OpLookup::new();
+    /// lookup.push_scope(|name, segment, num_operands| {
+    ///     let matches = {
+    ///         let top = segment.peek_stack_infos(num_operands);
+    ///         name == "+" && top.len() == 2 && top[0].type_id == TypeId::of::<i32>()
+    ///     };
+    ///     if matches {
+    ///         segment.op2(|a: i32, b: i32| a + b + 1)?; // Custom addition
+    ///         Ok(true)
+    ///     } else {
+    ///         Ok(false)
+    ///     }
+    /// });
+    /// let mut parser = CELParser::new(lookup);
+    /// parser.set_tokens(input.into_iter());
+    /// ```
+    pub fn op_lookup_mut(&mut self) -> &mut OpLookup {
+        &mut self.op_lookup
+    }
+
+    fn advance(&mut self) {
+        self.tokens.as_mut().expect("tokens set").next();
+    }
+
+    /// Peeks at the current token without consuming it.
+    ///
+    /// Returns `None` if there are no more tokens.
+    fn peek_token(&mut self) -> Option<&Token> {
+        self.tokens.as_mut().expect("tokens set").peek()
+    }
+
+    /// Builds a [`CELError`] at the current token's span (or call_site if no token).
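`advance` and `peek_token` above implement single-token lookahead: a probe inspects the current token and consumes it only on a match, so a failed probe leaves the stream untouched for the next alternative (this is exactly what `is_punctuation` does below). The same discipline on a plain character stream, as a standalone sketch (the `eat` helper is a toy, not part of the parser):

```rust
use std::iter::Peekable;
use std::str::Chars;

/// Consume the next char only if it equals `target`; otherwise leave the
/// stream untouched so another alternative can be probed.
fn eat(it: &mut Peekable<Chars<'_>>, target: char) -> bool {
    match it.peek() {
        Some(&c) if c == target => {
            it.next(); // commit: consume the matched char
            true
        }
        _ => false, // no match: nothing consumed
    }
}
```

The key invariant, here and in `is_punctuation`, is "peek first, advance only on match".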
+    fn error_at(&mut self, message: &str) -> CELError {
+        let span = match self.peek_token() {
+            Some(token) => {
+                use lex_lexer::HasSpan;
+                token.span()
+            }
+            None => Span::call_site(),
+        };
+        CELError::new(message, SourceSpan::from_proc_macro2(span))
+    }
+
+    fn is_punctuation(&mut self, target: &str) -> bool {
+        // Simply check if the current token is a Punct with the target operator
+        match self.peek_token() {
+            Some(Token::Punct { op, .. }) if op == target => {
+                self.advance();
+                true
+            }
+            _ => false,
+        }
+    }
+
+    /// `expression = or_expression ?eos?.`
+    pub fn is_expression(&mut self) -> Result<bool> {
+        if !self.is_or_expression()? {
+            return Ok(false);
+        }
+        if self.peek_token().is_some() {
+            return Err(self.error_at("unexpected token"));
+        }
+        Ok(true)
+    }
+
+    /// `or_expression = and_expression { "||" and_expression }.`
+    fn is_or_expression(&mut self) -> Result<bool> {
+        if self.is_and_expression()? {
+            while self.is_punctuation("||") {
+                if !self.is_and_expression()? {
+                    return Err(self.error_at("expected and_expression"));
+                }
+                if self.context.stack_ids.len() >= 2
+                    && let Err(e) = self.op_lookup.lookup("||", &mut self.context, 2)
+                {
+                    return Err(self.error_at(&format!("operation error: {}", e)));
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `and_expression = comparison_expression { "&&" comparison_expression }.`
+    fn is_and_expression(&mut self) -> Result<bool> {
+        if self.is_comparison_expression()? {
+            while self.is_punctuation("&&") {
+                if !self.is_comparison_expression()?
{
+                    return Err(self.error_at("expected comparison_expression"));
+                }
+                if self.context.stack_ids.len() >= 2
+                    && let Err(e) = self.op_lookup.lookup("&&", &mut self.context, 2)
+                {
+                    return Err(self.error_at(&format!("operation error: {}", e)));
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `comparison_expression = bitwise_or_expression [ ("==" | "!=" | "<" | ">" | "<=" | ">=") bitwise_or_expression ].`
+    fn is_comparison_expression(&mut self) -> Result<bool> {
+        if self.is_bitwise_or_expression()? {
+            // Check which operator we have (check longer operators first)
+            let op_name = if self.is_punctuation("==") {
+                Some("==")
+            } else if self.is_punctuation("!=") {
+                Some("!=")
+            } else if self.is_punctuation("<=") {
+                Some("<=")
+            } else if self.is_punctuation(">=") {
+                Some(">=")
+            } else if self.is_punctuation("<") {
+                Some("<")
+            } else if self.is_punctuation(">") {
+                Some(">")
+            } else {
+                None
+            };
+
+            if let Some(op_name) = op_name {
+                if !self.is_bitwise_or_expression()? {
+                    return Err(self.error_at("expected bitwise_or_expression"));
+                }
+                if self.context.stack_ids.len() >= 2
+                    && let Err(e) = self.op_lookup.lookup(op_name, &mut self.context, 2)
+                {
+                    return Err(self.error_at(&format!("operation error: {}", e)));
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `bitwise_or_expression = bitwise_xor_expression { "|" bitwise_xor_expression }.`
+    fn is_bitwise_or_expression(&mut self) -> Result<bool> {
+        if self.is_bitwise_xor_expression()? {
+            while self.is_punctuation("|") {
+                if !self.is_bitwise_xor_expression()?
{
+                    return Err(self.error_at("expected bitwise_xor_expression"));
+                }
+                if self.context.stack_ids.len() >= 2
+                    && let Err(e) = self.op_lookup.lookup("|", &mut self.context, 2)
+                {
+                    return Err(self.error_at(&format!("operation error: {}", e)));
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `bitwise_xor_expression = bitwise_and_expression { "^" bitwise_and_expression }.`
+    fn is_bitwise_xor_expression(&mut self) -> Result<bool> {
+        if self.is_bitwise_and_expression()? {
+            while self.is_punctuation("^") {
+                if !self.is_bitwise_and_expression()? {
+                    return Err(self.error_at("expected bitwise_and_expression"));
+                }
+                if self.context.stack_ids.len() >= 2
+                    && let Err(e) = self.op_lookup.lookup("^", &mut self.context, 2)
+                {
+                    return Err(self.error_at(&format!("operation error: {}", e)));
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `bitwise_and_expression = bitwise_shift_expression { "&" bitwise_shift_expression }.`
+    fn is_bitwise_and_expression(&mut self) -> Result<bool> {
+        if self.is_bitwise_shift_expression()? {
+            while self.is_punctuation("&") {
+                if !self.is_bitwise_shift_expression()? {
+                    return Err(self.error_at("expected bitwise_shift_expression"));
+                }
+                if self.context.stack_ids.len() >= 2
+                    && let Err(e) = self.op_lookup.lookup("&", &mut self.context, 2)
+                {
+                    return Err(self.error_at(&format!("operation error: {}", e)));
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `bitwise_shift_expression = additive_expression { ("<<" | ">>") additive_expression }.`
+    fn is_bitwise_shift_expression(&mut self) -> Result<bool> {
+        if self.is_additive_expression()? {
+            loop {
+                let op_name = if self.is_punctuation("<<") {
+                    Some("<<")
+                } else if self.is_punctuation(">>") {
+                    Some(">>")
+                } else {
+                    None
+                };
+
+                if let Some(op_name) = op_name {
+                    if !self.is_additive_expression()?
{
+                        return Err(self.error_at("expected additive_expression"));
+                    }
+                    if self.context.stack_ids.len() >= 2
+                        && let Err(e) = self.op_lookup.lookup(op_name, &mut self.context, 2)
+                    {
+                        return Err(self.error_at(&format!("operation error: {}", e)));
+                    }
+                } else {
+                    break;
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `additive_expression = multiplicative_expression { ("+" | "-") multiplicative_expression }.`
+    fn is_additive_expression(&mut self) -> Result<bool> {
+        if self.is_multiplicative_expression()? {
+            loop {
+                // Check which operator we have
+                let op_name = if self.is_punctuation("+") {
+                    Some("+")
+                } else if self.is_punctuation("-") {
+                    Some("-")
+                } else {
+                    None
+                };
+
+                // If we found an operator, parse the right operand and apply the operation
+                if let Some(op_name) = op_name {
+                    if !self.is_multiplicative_expression()? {
+                        return Err(self.error_at("expected multiplicative_expression"));
+                    }
+
+                    if self.context.stack_ids.len() >= 2
+                        && let Err(e) = self.op_lookup.lookup(op_name, &mut self.context, 2)
+                    {
+                        return Err(self.error_at(&format!("operation error: {}", e)));
+                    }
+                } else {
+                    break;
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `multiplicative_expression = unary_expression { ("*" | "/" | "%") unary_expression }.`
+    fn is_multiplicative_expression(&mut self) -> Result<bool> {
+        if self.is_unary_expression()? {
+            loop {
+                // Check which operator we have
+                let op_name = if self.is_punctuation("*") {
+                    Some("*")
+                } else if self.is_punctuation("/") {
+                    Some("/")
+                } else if self.is_punctuation("%") {
+                    Some("%")
+                } else {
+                    None
+                };
+
+                // If we found an operator, parse the right operand and apply the operation
+                if let Some(op_name) = op_name {
+                    if !self.is_unary_expression()?
{
+                        return Err(self.error_at("expected unary_expression"));
+                    }
+
+                    if self.context.stack_ids.len() >= 2
+                        && let Err(e) = self.op_lookup.lookup(op_name, &mut self.context, 2)
+                    {
+                        return Err(self.error_at(&format!("operation error: {}", e)));
+                    }
+                } else {
+                    break;
+                }
+            }
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    /// `unary_expression = (("-" | "!") unary_expression) | postfix_expression.`
+    fn is_unary_expression(&mut self) -> Result<bool> {
+        // Check for unary operators
+        let op_name = if self.is_punctuation("-") {
+            Some("-")
+        } else if self.is_punctuation("!") {
+            Some("!")
+        } else {
+            None
+        };
+
+        if let Some(op_name) = op_name {
+            if !self.is_unary_expression()? {
+                return Err(self.error_at("expected unary_expression"));
+            }
+            // Apply the unary operation (only if we have types)
+            if self.context.stack_ids.len() >= 1
+                && let Err(e) = self.op_lookup.lookup(op_name, &mut self.context, 1)
+            {
+                return Err(self.error_at(&format!("operation error: {}", e)));
+            }
+            Ok(true)
+        } else {
+            self.is_postfix_expression()
+        }
+    }
+
+    /// `postfix_expression = primary_expression { "(" parameter_list ")" }.`
+    fn is_postfix_expression(&mut self) -> Result<bool> {
+        if !self.is_primary_expression()? {
+            return Ok(false);
+        }
+        while matches!(
+            self.peek_token(),
+            Some(Token::OpenDelim {
+                delimiter: Delimiter::Parenthesis,
+                ..
+            })
+        ) {
+            self.advance(); // consume "("
+            let arg_count = self.parameter_list()?;
+            match self.peek_token() {
+                Some(Token::CloseDelim {
+                    delimiter: Delimiter::Parenthesis,
+                    ..
+                }) => {
+                    self.advance(); // consume ")"
+                }
+                _ => return Err(self.error_at("expected closing parenthesis")),
+            }
+            // Push the call operation: pops argument(s) then callee, invokes callee, pushes result.
+            // Stack order is [callee, arg1, arg2, ...]; lookup peeks top (arg_count + 1) entries.
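The call convention in the comment above (callee pushed first, then the arguments, with `"()"` consuming `arg_count + 1` stack entries) can be sketched standalone. The `Value` enum and `apply_call` below are illustrative only, for the one-argument case, and are not the runtime's actual representation:

```rust
/// Toy value stack: a callable is just another stack entry.
enum Value {
    Num(i64),
    Fun(fn(i64) -> i64),
}

/// Apply a one-argument call: pop the argument, then the callee beneath it,
/// invoke, and push the result. Returns None on a malformed stack.
fn apply_call(stack: &mut Vec<Value>) -> Option<()> {
    // The argument sits on top of the callee ([callee, arg]).
    let arg = match stack.pop()? {
        Value::Num(n) => n,
        Value::Fun(_) => return None, // argument must be a value
    };
    match stack.pop()? {
        Value::Fun(f) => {
            stack.push(Value::Num(f(arg)));
            Some(())
        }
        Value::Num(_) => None, // callee must be callable
    }
}
```

This is why `parameter_list` only needs to return a count: the operands are already on the stack in call order.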
+ if self.context.stack_ids.len() >= arg_count + 1
+ && let Err(e) = self
+ .op_lookup
+ .lookup("()", &mut self.context, arg_count + 1)
+ {
+ return Err(self.error_at(&format!("call: {}", e)));
+ }
+ }
+ Ok(true)
+ }
+
+ /// `parameter_list = [ or_expression { "," or_expression } ].`
+ ///
+ /// Returns the argument count.
+ fn parameter_list(&mut self) -> Result<usize> {
+ let mut count = 0;
+ if self.is_or_expression()? {
+ count += 1;
+ while self.is_punctuation(",") {
+ if !self.is_or_expression()? {
+ return Err(self.error_at("expected expression after comma"));
+ }
+ count += 1;
+ }
+ }
+ Ok(count)
+ }
+
+ /// `primary_expression = literal | identifier | "(" or_expression ")".`
+ fn is_primary_expression(&mut self) -> Result<bool> {
+ match self.peek_token() {
+ Some(Token::Literal(lit)) => {
+ // Clone the literal - syn's Lit types are Clone
+ let lit_clone = lit.clone();
+ self.advance();
+ // Push the literal to the context
+ push_literal(&mut self.context, lit_clone);
+ Ok(true)
+ }
+ Some(Token::Identifier(ident)) => {
+ let ident_name = ident.to_string();
+ self.advance();
+
+ // Look up identifier (variable/0-ary); value is pushed and may be a function.
+ self.op_lookup
+ .lookup(&ident_name, &mut self.context, 0)
+ .map_err(|_| self.error_at(&format!("Undefined identifier: {}", ident_name)))?;
+
+ Ok(true)
+ }
+ Some(Token::OpenDelim {
+ delimiter: Delimiter::Parenthesis,
+ ..
+ }) => {
+ self.advance(); // consume OpenDelim
+ // Recursively parse the expression inside parentheses
+ if !self.is_or_expression()? {
+ return Err(self.error_at("expected expression"));
+ }
+ // Expect CloseDelim
+ match self.peek_token() {
+ Some(Token::CloseDelim {
+ delimiter: Delimiter::Parenthesis,
+ ..
+ }) => {
+ self.advance(); // consume CloseDelim
+ Ok(true)
+ }
+ _ => Err(self.error_at("expected closing parenthesis")),
+ }
+ }
+ _ => Ok(false),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use anyhow;
+
+ #[test]
+ fn experiments() -> Result<()> {
+ let mut lookup = OpLookup::new();
+ lookup.push_scope(|name, segment, _num_operands| {
+ if name == "constant" {
+ segment.just(42i64);
+ return Ok(true);
+ }
+ Ok(false)
+ });
+ let mut parser = CELParser::new(lookup);
+ let line = line!() + 1;
+ let source = r#"
+ (("hello" + " world") == constant) && (15i64 < constant)
+ "#;
+ // assert!(parser.parse_str(source)?.call0::<bool>()?);
+
+ if let Err(e) = parser.parse_str(source) {
+ println!("{}", e.format_rustc_style(source, file!(), line));
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn simple_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("10");
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap().call0::<i64>().unwrap(), 10);
+ }
+
+ #[test]
+ fn float_literal() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("3.14");
+ assert!(result.is_ok());
+ let value = result.unwrap().call0::<f64>().unwrap();
+ assert!((value - 3.14).abs() < 1e-10);
+ }
+
+ #[test]
+ fn boolean_literal() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("true");
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap().call0::<bool>().unwrap(), true);
+ }
+
+ #[test]
+ fn string_literal() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str(r#""hello""#);
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap().call0::<String>().unwrap(), "hello");
+ }
+
+ #[test]
+ fn string_concatenation() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str(r#""a" + "b""#);
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap().call0::<String>().unwrap(), "ab");
+ }
+
+ #[test]
+ fn incomplete_expression() {
+ let mut parser =
CELParser::new(OpLookup::new());
+ let result = parser.parse_str("10 + 25 25");
+ let err = match result {
+ Ok(_) => panic!("expected parse error"),
+ Err(e) => e,
+ };
+ assert_eq!(err.message(), "unexpected token");
+ }
+
+ #[test]
+ fn arithmetic_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("10 + 20 * 30");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn parenthesized_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("(10 + 20) * 30");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn complex_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("10 + 20 * (30 - 5) / 2");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn logical_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("true && false || true");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn comparison_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("10 == 20 && 30 > 40");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn bitwise_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("1 | 2 & 3 ^ 4");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn shift_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("8 << 2 + 16 >> 1");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn unary_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("-10 + -20");
+ assert!(result.is_ok());
+ }
+
+ #[test]
+ fn double_negation() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("!!true");
+ assert!(
+ result.is_ok(),
+ "Failed to parse !!true: {}",
+ result.err().unwrap()
+ );
+ assert_eq!(result.unwrap().call0::<bool>().unwrap(), true); // !!true = true
+ }
+
+ #[test]
+ fn double_minus() {
+ let mut parser =
CELParser::new(OpLookup::new());
+ let result = parser.parse_str("--5");
+ assert!(
+ result.is_ok(),
+ "Failed to parse --5: {}",
+ result.err().unwrap()
+ );
+ assert_eq!(result.unwrap().call0::<i64>().unwrap(), 5);
+ }
+
+ #[test]
+ fn chained_unary_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("!!false || !!true");
+ if let Err(ref e) = result {
+ eprintln!("Error: {:?}", e);
+ eprintln!("Error message: {}", e.message());
+ }
+ assert!(result.is_ok(), "Failed to parse: {}", result.err().unwrap());
+ assert_eq!(result.unwrap().call0::<bool>().unwrap(), true);
+ }
+
+ #[test]
+ fn invalid_expression() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("+");
+ assert!(result.is_err());
+ }
+
+ /// Helper function to strip ANSI escape codes from a string for testing purposes
+ fn strip_ansi_codes(input: &str) -> String {
+ // Manually scan and remove ANSI escape sequences (no regex dependency).
+ // ANSI escape sequences start with ESC (0x1B) followed by '[' and end with a letter
+ let mut result = String::new();
+ let mut chars = input.chars().peekable();
+
+ while let Some(ch) = chars.next() {
+ if ch == '\x1B' {
+ // Found ESC, check if it's followed by '['
+ if chars.peek() == Some(&'[') {
+ chars.next(); // consume '['
+ // Skip until we find a letter (which ends the escape sequence)
+ while let Some(ch) = chars.next() {
+ if ch.is_ascii_alphabetic() {
+ break;
+ }
+ }
+ } else {
+ result.push(ch);
+ }
+ } else {
+ result.push(ch);
+ }
+ }
+
+ result
+ }
+
+ #[test]
+ fn error_formatting() {
+ let source = "10 + 20 30"; // Missing operator between 20 and 30
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str(source);
+
+ // This should fail parsing
+ assert!(result.is_err());
+
+ // Test error message from result
+ let err = match &result {
+ Ok(_) => panic!("expected parse error"),
+ Err(e) => e,
+ };
+ assert_eq!(err.message(), "unexpected token");
+
+ // Test error formatting
+ let formatted_error = err.format_rustc_style(source, "test.cel", 1u32); + + // Strip ANSI codes for testing + let formatted = strip_ansi_codes(&formatted_error); + assert!(formatted.contains("error: unexpected token")); + assert!(formatted.contains("test.cel:1:")); // Should include line number + assert!(formatted.contains("1 | 10 + 20 30")); // Should show the line with line number + assert!(formatted.contains("^")); // Should have carets pointing to the error + } + + #[test] + fn error_formatting_with_line_offset() { + let source = "10 + 20 30"; // Missing operator between 20 and 30 + let mut parser = CELParser::new(OpLookup::new()); + let result = parser.parse_str(source); + + // This should fail parsing + assert!(result.is_err()); + + // Test error formatting with line offset (as if expression starts at line 42) + let err = match &result { + Ok(_) => panic!("expected parse error"), + Err(e) => e, + }; + let formatted_error = err.format_rustc_style(source, "large_file.rs", 42u32); + + // Strip ANSI codes for testing + let formatted = strip_ansi_codes(&formatted_error); + assert!(formatted.contains("error: unexpected token")); + assert!(formatted.contains("large_file.rs:42:")); // Should show offset line number + assert!(formatted.contains("42 | 10 + 20 30")); // Should show the line with offset line number + assert!(formatted.contains("^")); // Should have carets pointing to the error + } + + #[test] + fn print_error_formatting() { + let line = line!() + 1; + let source = r#" + + 10 + 20 30 // Unexpected token + + "#; + + let mut parser = CELParser::new(OpLookup::new()); + let result = parser.parse_str(source); + + // Parse should fail due to unexpected token + assert!(result.is_err(), "Expected parsing to fail"); + + let err = match &result { + Ok(_) => panic!("expected parse error"), + Err(e) => e, + }; + eprintln!( + "DEBUG: span.start.line = {}, span.start.column = {}", + err.span().start.line, + err.span().start.column + ); + + // Format the error + let 
formatted_error = err.format_rustc_style(source, file!(), line);
+ println!("{}", formatted_error);
+
+ // Strip ANSI codes for testing
+ let formatted = strip_ansi_codes(&formatted_error);
+
+ // The source string has 3 lines:
+ // Line 0: empty
+ // Line 1: empty
+ // Line 2: " 10 + 20 30 // Unexpected token"
+ // So the error should be on line + 2
+ let expected_line = line + 2;
+
+ assert!(
+ formatted.contains("error: unexpected token"),
+ "Should contain error message, got: {}",
+ formatted
+ );
+ assert!(
+ formatted.contains(&format!("{}:", expected_line)),
+ "Should show error on line {}, got: {}",
+ expected_line,
+ formatted
+ );
+ assert!(
+ formatted.contains("30"),
+ "Should show the source line with '30', got: {}",
+ formatted
+ );
+ assert!(
+ formatted.contains("^"),
+ "Should have carets pointing to error, got: {}",
+ formatted
+ );
+ }
+
+ #[test]
+ fn test_addition_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("10 + 20")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<i64>()?;
+ assert_eq!(result, 30);
+ Ok(())
+ }
+
+ #[test]
+ fn test_multiplication_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("3 * 7")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<i64>()?;
+ assert_eq!(result, 21);
+ Ok(())
+ }
+
+ #[test]
+ fn test_complex_arithmetic_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("10 + 20 * 3")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<i64>()?;
+ assert_eq!(result, 70); // 10 + (20 * 3) = 10 + 60 = 70
+ Ok(())
+ }
+
+ #[test]
+ fn test_parenthesized_arithmetic_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("(10 + 20) * 3")
+ .map_err(|e|
anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<i64>()?;
+ assert_eq!(result, 90); // (10 + 20) * 3 = 30 * 3 = 90
+ Ok(())
+ }
+
+ #[test]
+ fn test_comparison_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("10 < 20")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<bool>()?;
+ assert_eq!(result, true);
+ Ok(())
+ }
+
+ #[test]
+ fn test_logical_and_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("true && false")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<bool>()?;
+ assert_eq!(result, false);
+ Ok(())
+ }
+
+ #[test]
+ fn test_unary_negation_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("-42")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<i64>()?;
+ assert_eq!(result, -42);
+ Ok(())
+ }
+
+ #[test]
+ fn test_logical_not_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("!true")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<bool>()?;
+ assert_eq!(result, false);
+ Ok(())
+ }
+
+ #[test]
+ fn test_u32_addition_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("10u32 + 20u32")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<u32>()?;
+ assert_eq!(result, 30);
+ Ok(())
+ }
+
+ #[test]
+ fn test_identifier_with_scope() -> anyhow::Result<()> {
+ let mut lookup = OpLookup::new();
+ lookup.push_scope(|name, segment, num_operands| {
+ if num_operands == 0 {
+ match name {
+ "x" => {
+ segment.op0(|| 10i32);
+ Ok(true)
+ }
+ "y" => {
+ segment.op0(|| 20i32);
+ Ok(true)
+ }
+ _ => Ok(false),
+ }
+ } else {
+ Ok(false)
+ }
+ });
+ let mut parser = CELParser::new(lookup);
+ let mut segment = parser
+ .parse_str("x + y")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<i32>()?;
+ assert_eq!(result, 30);
+ Ok(())
+ }
+
+ #[test]
+ fn test_undefined_identifier_error() {
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str("undefined_var + 10");
+
+ assert!(result.is_err());
+ if let Err(e) = result {
+ let error_msg = format!("{:?}", e);
+ assert!(
+ error_msg.contains("Undefined identifier: undefined_var"),
+ "Error message should contain 'Undefined identifier: undefined_var', got: {}",
+ error_msg
+ );
+ }
+ }
+
+ #[test]
+ fn test_undefined_identifier_error_formatting() {
+ let input = "undefined_var + 10";
+ let mut parser = CELParser::new(OpLookup::new());
+ let result = parser.parse_str(input);
+
+ assert!(result.is_err());
+ if let Err(e) = result {
+ let formatted_error = e.format_rustc_style(input, "test.cel", 1);
+ assert!(formatted_error.contains("Undefined identifier"));
+ assert!(formatted_error.contains("undefined_var"));
+ assert!(formatted_error.contains("test.cel"));
+ }
+ }
+
+ #[test]
+ fn test_float_arithmetic_execution() -> anyhow::Result<()> {
+ let mut parser = CELParser::new(OpLookup::new());
+ let mut segment = parser
+ .parse_str("3.5 * 2.0")
+ .map_err(|e| anyhow::anyhow!("{}", e))?;
+ let result = segment.call0::<f64>()?;
+ assert_eq!(result, 7.0);
+ Ok(())
+ }
+}
diff --git a/cel-runtime/src/parser/op_table.rs b/cel-runtime/src/parser/op_table.rs
new file mode 100644
index 0000000..10c3506
--- /dev/null
+++ b/cel-runtime/src/parser/op_table.rs
@@ -0,0 +1,951 @@
+//! Operation table for dynamically dispatching operations based on type signatures.
+//!
+//! This module provides a scope-based registry for operations that can be looked up
+//! based on an operation name (string) and the types of the operands. Built-in operations
+//! use compile-time hash tables (via `phf`) for efficient lookup, while custom operations
+//!
can be added dynamically through scope functions.
+//!
+//! # Design
+//!
+//! - **Operator symbols as names**: Operations are identified by their operator symbols
+//! (e.g., `"+"`, `"-"`, `"*"`) to avoid conflicts with valid identifiers.
+//! - **Function pointers**: Built-in operations use stateless function pointers for
+//! zero-allocation dispatch.
+//! - **Scope stack**: Custom operations are handled through a stack of scope functions
+//! that can be pushed and popped as needed.
+//! - **Type optimization**: Since all built-in operations have matching operand types,
+//! signatures store a single `TypeId` plus arity rather than arrays.
+
+use crate::DynSegment;
+use anyhow::{Result, anyhow};
+use once_cell::sync::Lazy;
+use phf::phf_map;
+use std::any::TypeId;
+
+/// A function that pushes an operation onto a DynSegment.
+///
+/// This is a simple function pointer since built-in operations have no state.
+pub type OpFn = fn(&mut DynSegment) -> Result<()>;
+
+/// A scope function that attempts to resolve and apply an operation.
+///
+/// Receives the operation name, the segment, and the number of operands on top of the stack.
+/// The scope may call `segment.peek_stack_infos(num_operands)` to inspect types. Returns
+/// `Ok(true)` if handled, `Ok(false)` if not found, or `Err` on error.
+pub type ScopeFn = Box<dyn Fn(&str, &mut DynSegment, usize) -> Result<bool> + Send + Sync>;
+
+/// A signature for an operation with matching operand types.
+///
+/// For example, `u32 + u32 -> u32` would have `type_id_index = TYPE_U32`
+/// and `arity = 2`. This optimization reduces memory usage by ~50% compared to
+/// storing full type arrays.
+#[derive(Clone, Copy)]
+struct OpSignature {
+ /// Index into TYPE_IDS vector for the TypeId that all operands must match
+ type_id_index: usize,
+ /// Number of operands this operation accepts
+ arity: u8,
+ /// Function pointer to the operation implementation
+ op_fn: OpFn,
+}
+
+impl OpSignature {
+ /// Returns the TypeId for this signature.
+ fn type_id(&self) -> TypeId {
+ TYPE_IDS[self.type_id_index]
+ }
+}
+
+/// Single lazy-initialized vector containing all unique TypeIds for built-in types.
+///
+/// This avoids duplicating TypeId storage across all operation signatures.
+static TYPE_IDS: Lazy<Vec<TypeId>> = Lazy::new(|| {
+ vec![
+ TypeId::of::<u8>(),
+ TypeId::of::<u16>(),
+ TypeId::of::<u32>(),
+ TypeId::of::<u64>(),
+ TypeId::of::<u128>(),
+ TypeId::of::<usize>(),
+ TypeId::of::<i8>(),
+ TypeId::of::<i16>(),
+ TypeId::of::<i32>(),
+ TypeId::of::<i64>(),
+ TypeId::of::<i128>(),
+ TypeId::of::<isize>(),
+ TypeId::of::<f32>(),
+ TypeId::of::<f64>(),
+ TypeId::of::<bool>(),
+ TypeId::of::<String>(),
+ ]
+});
+
+// Type index constants for readability
+const TYPE_U8: usize = 0;
+const TYPE_U16: usize = 1;
+const TYPE_U32: usize = 2;
+const TYPE_U64: usize = 3;
+const TYPE_U128: usize = 4;
+const TYPE_USIZE: usize = 5;
+const TYPE_I8: usize = 6;
+const TYPE_I16: usize = 7;
+const TYPE_I32: usize = 8;
+const TYPE_I64: usize = 9;
+const TYPE_I128: usize = 10;
+const TYPE_ISIZE: usize = 11;
+const TYPE_F32: usize = 12;
+const TYPE_F64: usize = 13;
+const TYPE_BOOL: usize = 14;
+const TYPE_STR: usize = 15;
+
+// Helper macro to reduce boilerplate in signature definitions
+macro_rules!
sig { + ($type_idx:expr, $arity:expr, $closure:expr) => { + OpSignature { + type_id_index: $type_idx, + arity: $arity, + op_fn: $closure, + } + }; +} + +// Addition signatures +static ADD_SIGNATURES: &[OpSignature] = &[ + sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a.wrapping_add(b))), + sig!(TYPE_U16, 2, |seg| seg + .op2(|a: u16, b: u16| a.wrapping_add(b))), + sig!(TYPE_U32, 2, |seg| seg + .op2(|a: u32, b: u32| a.wrapping_add(b))), + sig!(TYPE_U64, 2, |seg| seg + .op2(|a: u64, b: u64| a.wrapping_add(b))), + sig!(TYPE_U128, 2, |seg| seg + .op2(|a: u128, b: u128| a.wrapping_add(b))), + sig!(TYPE_USIZE, 2, |seg| seg + .op2(|a: usize, b: usize| a.wrapping_add(b))), + sig!(TYPE_I8, 2, |seg| seg.op2r(|a: i8, b: i8| a + .checked_add(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I16, 2, |seg| seg.op2r(|a: i16, b: i16| a + .checked_add(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I32, 2, |seg| seg.op2r(|a: i32, b: i32| a + .checked_add(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I64, 2, |seg| seg.op2r(|a: i64, b: i64| a + .checked_add(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I128, 2, |seg| seg.op2r(|a: i128, b: i128| a + .checked_add(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_ISIZE, 2, |seg| seg.op2r(|a: isize, b: isize| a + .checked_add(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a + b)), + sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a + b)), + sig!(TYPE_STR, 2, |seg| seg.op2(|a: String, b: String| a + &b)), +]; + +// Subtraction signatures (both binary and unary) +static SUB_SIGNATURES: &[OpSignature] = &[ + // Binary subtraction + sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a.wrapping_sub(b))), + sig!(TYPE_U16, 2, |seg| seg + .op2(|a: u16, b: u16| a.wrapping_sub(b))), + sig!(TYPE_U32, 2, |seg| seg + .op2(|a: u32, b: u32| a.wrapping_sub(b))), + sig!(TYPE_U64, 2, |seg| seg + .op2(|a: u64, 
b: u64| a.wrapping_sub(b))), + sig!(TYPE_U128, 2, |seg| seg + .op2(|a: u128, b: u128| a.wrapping_sub(b))), + sig!(TYPE_USIZE, 2, |seg| seg + .op2(|a: usize, b: usize| a.wrapping_sub(b))), + sig!(TYPE_I8, 2, |seg| seg.op2r(|a: i8, b: i8| a + .checked_sub(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I16, 2, |seg| seg.op2r(|a: i16, b: i16| a + .checked_sub(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I32, 2, |seg| seg.op2r(|a: i32, b: i32| a + .checked_sub(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I64, 2, |seg| seg.op2r(|a: i64, b: i64| a + .checked_sub(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I128, 2, |seg| seg.op2r(|a: i128, b: i128| a + .checked_sub(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_ISIZE, 2, |seg| seg.op2r(|a: isize, b: isize| a + .checked_sub(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a - b)), + sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a - b)), + // Unary negation + sig!(TYPE_I8, 1, |seg| seg.op1r(|a: i8| a + .checked_neg() + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I16, 1, |seg| seg.op1r(|a: i16| a + .checked_neg() + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I32, 1, |seg| seg.op1r(|a: i32| a + .checked_neg() + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I64, 1, |seg| seg.op1r(|a: i64| a + .checked_neg() + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I128, 1, |seg| seg.op1r(|a: i128| a + .checked_neg() + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_ISIZE, 1, |seg| seg.op1r(|a: isize| a + .checked_neg() + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_F32, 1, |seg| seg.op1(|a: f32| -a)), + sig!(TYPE_F64, 1, |seg| seg.op1(|a: f64| -a)), +]; + +// Multiplication signatures +static MUL_SIGNATURES: &[OpSignature] = &[ + sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| 
a.wrapping_mul(b))), + sig!(TYPE_U16, 2, |seg| seg + .op2(|a: u16, b: u16| a.wrapping_mul(b))), + sig!(TYPE_U32, 2, |seg| seg + .op2(|a: u32, b: u32| a.wrapping_mul(b))), + sig!(TYPE_U64, 2, |seg| seg + .op2(|a: u64, b: u64| a.wrapping_mul(b))), + sig!(TYPE_U128, 2, |seg| seg + .op2(|a: u128, b: u128| a.wrapping_mul(b))), + sig!(TYPE_USIZE, 2, |seg| seg + .op2(|a: usize, b: usize| a.wrapping_mul(b))), + sig!(TYPE_I8, 2, |seg| seg.op2r(|a: i8, b: i8| a + .checked_mul(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I16, 2, |seg| seg.op2r(|a: i16, b: i16| a + .checked_mul(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I32, 2, |seg| seg.op2r(|a: i32, b: i32| a + .checked_mul(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I64, 2, |seg| seg.op2r(|a: i64, b: i64| a + .checked_mul(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_I128, 2, |seg| seg.op2r(|a: i128, b: i128| a + .checked_mul(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_ISIZE, 2, |seg| seg.op2r(|a: isize, b: isize| a + .checked_mul(b) + .ok_or_else(|| anyhow!("arithmetic overflow")))), + sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a * b)), + sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a * b)), +]; + +// Division signatures +// +// Integer division uses `checked_div` via `op2r` so that division by zero returns an error +// instead of panicking. Float division keeps `op2` (IEEE 754 defines x/0.0 as inf/nan). 
+static DIV_SIGNATURES: &[OpSignature] = &[ + sig!(TYPE_U8, 2, |seg| seg.op2r(|a: u8, b: u8| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U16, 2, |seg| seg.op2r(|a: u16, b: u16| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U32, 2, |seg| seg.op2r(|a: u32, b: u32| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U64, 2, |seg| seg.op2r(|a: u64, b: u64| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U128, 2, |seg| seg.op2r(|a: u128, b: u128| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_USIZE, 2, |seg| seg.op2r(|a: usize, b: usize| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I8, 2, |seg| seg.op2r(|a: i8, b: i8| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I16, 2, |seg| seg.op2r(|a: i16, b: i16| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I32, 2, |seg| seg.op2r(|a: i32, b: i32| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I64, 2, |seg| seg.op2r(|a: i64, b: i64| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I128, 2, |seg| seg.op2r(|a: i128, b: i128| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_ISIZE, 2, |seg| seg.op2r(|a: isize, b: isize| a + .checked_div(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a / b)), + sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a / b)), +]; + +// Modulo signatures +// +// Integer modulo uses `checked_rem` via `op2r` so that division by zero returns an error +// instead of panicking. Float modulo keeps `op2` (x % 0.0 yields NaN without panicking). 
+static MOD_SIGNATURES: &[OpSignature] = &[ + sig!(TYPE_U8, 2, |seg| seg.op2r(|a: u8, b: u8| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U16, 2, |seg| seg.op2r(|a: u16, b: u16| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U32, 2, |seg| seg.op2r(|a: u32, b: u32| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U64, 2, |seg| seg.op2r(|a: u64, b: u64| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_U128, 2, |seg| seg.op2r(|a: u128, b: u128| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_USIZE, 2, |seg| seg.op2r(|a: usize, b: usize| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I8, 2, |seg| seg.op2r(|a: i8, b: i8| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I16, 2, |seg| seg.op2r(|a: i16, b: i16| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I32, 2, |seg| seg.op2r(|a: i32, b: i32| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I64, 2, |seg| seg.op2r(|a: i64, b: i64| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_I128, 2, |seg| seg.op2r(|a: i128, b: i128| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_ISIZE, 2, |seg| seg.op2r(|a: isize, b: isize| a + .checked_rem(b) + .ok_or_else(|| anyhow!("division by zero")))), + sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a % b)), + sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a % b)), +]; + +// Bitwise AND signatures +static BITWISE_AND_SIGNATURES: &[OpSignature] = &[ + sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a & b)), + sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a & b)), + sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a & b)), + sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a & b)), + sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: 
u128| a & b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a & b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a & b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a & b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a & b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a & b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a & b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a & b)),
+];
+
+// Bitwise OR signatures
+static BITWISE_OR_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a | b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a | b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a | b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a | b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a | b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a | b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a | b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a | b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a | b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a | b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a | b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a | b)),
+];
+
+// Bitwise XOR signatures
+static BITWISE_XOR_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a ^ b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a ^ b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a ^ b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a ^ b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a ^ b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a ^ b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a ^ b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a ^ b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a ^ b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a ^ b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a ^ b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a ^ b)),
+];
+
+// Left shift signatures
+static LEFT_SHIFT_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a << b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a << b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a << b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a << b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a << b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a << b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2r(|a: i8, b: i8| a
+        .checked_shl(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I16, 2, |seg| seg.op2r(|a: i16, b: i16| a
+        .checked_shl(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I32, 2, |seg| seg.op2r(|a: i32, b: i32| a
+        .checked_shl(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I64, 2, |seg| seg.op2r(|a: i64, b: i64| a
+        .checked_shl(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I128, 2, |seg| seg.op2r(|a: i128, b: i128| a
+        .checked_shl(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2r(|a: isize, b: isize| a
+        .checked_shl(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+];
+
+// Right shift signatures
+static RIGHT_SHIFT_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a >> b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a >> b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a >> b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a >> b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a >> b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a >> b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2r(|a: i8, b: i8| a
+        .checked_shr(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I16, 2, |seg| seg.op2r(|a: i16, b: i16| a
+        .checked_shr(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I32, 2, |seg| seg.op2r(|a: i32, b: i32| a
+        .checked_shr(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I64, 2, |seg| seg.op2r(|a: i64, b: i64| a
+        .checked_shr(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_I128, 2, |seg| seg.op2r(|a: i128, b: i128| a
+        .checked_shr(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2r(|a: isize, b: isize| a
+        .checked_shr(b as u32)
+        .ok_or_else(|| anyhow!("shift overflow")))),
+];
+
+// Logical AND signatures
+static LOGICAL_AND_SIGNATURES: &[OpSignature] =
+    &[sig!(TYPE_BOOL, 2, |seg| seg.op2(|a: bool, b: bool| a && b))];
+
+// Logical OR signatures
+static LOGICAL_OR_SIGNATURES: &[OpSignature] =
+    &[sig!(TYPE_BOOL, 2, |seg| seg.op2(|a: bool, b: bool| a || b))];
+
+// Logical NOT signatures
+static LOGICAL_NOT_SIGNATURES: &[OpSignature] = &[sig!(TYPE_BOOL, 1, |seg| seg.op1(|a: bool| !a))];
+
+// Equality signatures
+static EQUAL_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a == b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a == b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a == b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a == b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a == b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a == b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a == b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a == b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a == b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a == b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a == b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a == b)),
+    sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a == b)),
+    sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a == b)),
+    sig!(TYPE_BOOL, 2, |seg| seg.op2(|a: bool, b: bool| a == b)),
+    sig!(TYPE_STR, 2, |seg| seg.op2(|a: String, b: String| a == b)),
+];
+
+// Inequality signatures
+static NOT_EQUAL_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a != b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a != b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a != b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a != b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a != b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a != b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a != b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a != b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a != b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a != b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a != b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a != b)),
+    sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a != b)),
+    sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a != b)),
+    sig!(TYPE_BOOL, 2, |seg| seg.op2(|a: bool, b: bool| a != b)),
+    sig!(TYPE_STR, 2, |seg| seg.op2(|a: String, b: String| a != b)),
+];
+
+// Less than signatures
+static LESS_THAN_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a < b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a < b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a < b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a < b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a < b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a < b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a < b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a < b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a < b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a < b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a < b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a < b)),
+    sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a < b)),
+    sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a < b)),
+    sig!(TYPE_STR, 2, |seg| seg.op2(|a: String, b: String| a < b)),
+];
+
+// Less than or equal signatures
+static LESS_THAN_OR_EQUAL_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a <= b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a <= b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a <= b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a <= b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a <= b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a <= b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a <= b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a <= b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a <= b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a <= b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a <= b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a <= b)),
+    sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a <= b)),
+    sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a <= b)),
+    sig!(TYPE_STR, 2, |seg| seg.op2(|a: String, b: String| a <= b)),
+];
+
+// Greater than signatures
+static GREATER_THAN_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a > b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a > b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a > b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a > b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a > b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a > b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a > b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a > b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a > b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a > b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a > b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a > b)),
+    sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a > b)),
+    sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a > b)),
+    sig!(TYPE_STR, 2, |seg| seg.op2(|a: String, b: String| a > b)),
+];
+
+// Greater than or equal signatures
+static GREATER_THAN_OR_EQUAL_SIGNATURES: &[OpSignature] = &[
+    sig!(TYPE_U8, 2, |seg| seg.op2(|a: u8, b: u8| a >= b)),
+    sig!(TYPE_U16, 2, |seg| seg.op2(|a: u16, b: u16| a >= b)),
+    sig!(TYPE_U32, 2, |seg| seg.op2(|a: u32, b: u32| a >= b)),
+    sig!(TYPE_U64, 2, |seg| seg.op2(|a: u64, b: u64| a >= b)),
+    sig!(TYPE_U128, 2, |seg| seg.op2(|a: u128, b: u128| a >= b)),
+    sig!(TYPE_USIZE, 2, |seg| seg.op2(|a: usize, b: usize| a >= b)),
+    sig!(TYPE_I8, 2, |seg| seg.op2(|a: i8, b: i8| a >= b)),
+    sig!(TYPE_I16, 2, |seg| seg.op2(|a: i16, b: i16| a >= b)),
+    sig!(TYPE_I32, 2, |seg| seg.op2(|a: i32, b: i32| a >= b)),
+    sig!(TYPE_I64, 2, |seg| seg.op2(|a: i64, b: i64| a >= b)),
+    sig!(TYPE_I128, 2, |seg| seg.op2(|a: i128, b: i128| a >= b)),
+    sig!(TYPE_ISIZE, 2, |seg| seg.op2(|a: isize, b: isize| a >= b)),
+    sig!(TYPE_F32, 2, |seg| seg.op2(|a: f32, b: f32| a >= b)),
+    sig!(TYPE_F64, 2, |seg| seg.op2(|a: f64, b: f64| a >= b)),
+    sig!(TYPE_STR, 2, |seg| seg.op2(|a: String, b: String| a >= b)),
+];
+
+/// Compile-time perfect hash map for built-in operations.
+///
+/// Maps operator symbols to their signature arrays for O(1) lookup.
+static BUILTINS: phf::Map<&'static str, &'static [OpSignature]> = phf_map! {
+    "+" => ADD_SIGNATURES,
+    "-" => SUB_SIGNATURES,
+    "*" => MUL_SIGNATURES,
+    "/" => DIV_SIGNATURES,
+    "%" => MOD_SIGNATURES,
+    "&" => BITWISE_AND_SIGNATURES,
+    "|" => BITWISE_OR_SIGNATURES,
+    "^" => BITWISE_XOR_SIGNATURES,
+    "<<" => LEFT_SHIFT_SIGNATURES,
+    ">>" => RIGHT_SHIFT_SIGNATURES,
+    "&&" => LOGICAL_AND_SIGNATURES,
+    "||" => LOGICAL_OR_SIGNATURES,
+    "!" => LOGICAL_NOT_SIGNATURES,
+    "==" => EQUAL_SIGNATURES,
+    "!=" => NOT_EQUAL_SIGNATURES,
+    "<" => LESS_THAN_SIGNATURES,
+    "<=" => LESS_THAN_OR_EQUAL_SIGNATURES,
+    ">" => GREATER_THAN_SIGNATURES,
+    ">=" => GREATER_THAN_OR_EQUAL_SIGNATURES,
+};
+
+/// Built-in operation scope.
+///
+/// Provides lookup for standard operations using a compile-time hash table.
+struct BuiltinScope;
+
+impl BuiltinScope {
+    /// Attempts to find and apply a built-in operation.
+    ///
+    /// Returns `Ok(true)` if found and applied, `Ok(false)` if not found.
+    fn lookup(&self, name: &str, segment: &mut DynSegment, num_operands: usize) -> Result<bool> {
+        let stack_infos = segment.peek_stack_infos(num_operands);
+        if let Some(signatures) = BUILTINS.get(name) {
+            for sig in *signatures {
+                let matches = sig.arity as usize == stack_infos.len()
+                    && stack_infos.iter().all(|info| info.type_id == sig.type_id());
+
+                if matches {
+                    (sig.op_fn)(segment)?;
+                    return Ok(true);
+                }
+            }
+        }
+        Ok(false)
+    }
+}
+
+/// Operation lookup with scope stack support.
+///
+/// Provides a stack of scopes for operation resolution, with built-in operations
+/// as the fallback. Scopes are searched in LIFO order (most recently pushed first).
+///
+/// # Examples
+///
+/// ```rust
+/// use cel_runtime::parser::op_table::OpLookup;
+/// use cel_runtime::DynSegment;
+/// use std::any::TypeId;
+///
+/// let mut lookup = OpLookup::new();
+///
+/// // Use built-in addition
+/// let mut segment = DynSegment::new::<()>();
+/// segment.just(10u32);
+/// segment.just(20u32);
+/// lookup.lookup("+", &mut segment, 2).unwrap();
+/// assert_eq!(segment.call0::<u32>().unwrap(), 30);
+/// ```
+pub struct OpLookup {
+    scopes: Vec<ScopeFn>,
+    builtin_scope: BuiltinScope,
+}
+
+impl OpLookup {
+    /// Creates a new operation lookup with only built-in operations.
+    pub fn new() -> Self {
+        OpLookup {
+            scopes: Vec::new(),
+            builtin_scope: BuiltinScope,
+        }
+    }
+
+    /// Pushes a new scope onto the stack.
+    ///
+    /// Accepts a closure directly; it is boxed internally. The scope should return
+    /// `Ok(true)` if it handled the operation, `Ok(false)` to pass to the next scope,
+    /// or `Err` on error.
+    pub fn push_scope<F>(&mut self, scope: F)
+    where
+        F: Fn(&str, &mut DynSegment, usize) -> Result<bool> + Send + Sync + 'static,
+    {
+        self.scopes.push(Box::new(scope));
+    }
+
+    /// Pops the most recent scope from the stack.
+    ///
+    /// Returns the popped scope, or `None` if the stack is empty.
+    pub fn pop_scope(&mut self) -> Option<ScopeFn> {
+        self.scopes.pop()
+    }
+
+    /// Looks up and applies an operation.
+    ///
+    /// Searches scopes in LIFO order, then falls back to built-in operations.
+    ///
+    /// # Arguments
+    ///
+    /// * `name` - The operation name (e.g., `"+"`, `"-"`, or a custom identifier)
+    /// * `segment` - The DynSegment to apply the operation to
+    /// * `num_operands` - Number of top stack entries that are operands (e.g. 2 for binary ops)
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if no scope or built-in operation can handle the request.
+    /// Error messages report type names from the top stack entries, not raw type ids.
+    pub fn lookup(&self, name: &str, segment: &mut DynSegment, num_operands: usize) -> Result<()> {
+        for scope in self.scopes.iter().rev() {
+            if scope(name, segment, num_operands)? {
+                return Ok(());
+            }
+        }
+
+        if self.builtin_scope.lookup(name, segment, num_operands)? {
+            return Ok(());
+        }
+
+        let infos = segment.peek_stack_infos(num_operands);
+        let mut type_names = String::new();
+        for (i, info) in infos.iter().enumerate() {
+            if i > 0 {
+                type_names.push_str(", ");
+            }
+            type_names.push_str(info.type_name.as_ref());
+        }
+        Err(anyhow!(
+            "Operation '{}' not found for types [{}]",
+            name,
+            type_names
+        ))
+    }
+}
+
+impl Default for OpLookup {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_addition_u32() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10u32);
+        segment.just(20u32);
+        lookup.lookup("+", &mut segment, 2)?;
+        assert_eq!(segment.call0::<u32>()?, 30);
+        Ok(())
+    }
+
+    #[test]
+    fn test_subtraction_i32() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(50i32);
+        segment.just(20i32);
+        lookup.lookup("-", &mut segment, 2)?;
+        assert_eq!(segment.call0::<i32>()?, 30);
+        Ok(())
+    }
+
+    #[test]
+    fn test_arithmetic_overflow() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(i32::MAX);
+        segment.just(1i32);
+        lookup.lookup("+", &mut segment, 2)?;
+        let result = segment.call0::<i32>();
+        assert!(result.is_err());
+        assert!(
+            result
+                .unwrap_err()
+                .to_string()
+                .contains("arithmetic overflow"),
+            "error message should mention arithmetic overflow"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_division_by_zero() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10i32);
+        segment.just(0i32);
+        lookup.lookup("/", &mut segment, 2)?;
+        let result = segment.call0::<i32>();
+        assert!(result.is_err());
+        assert!(
+            result.unwrap_err().to_string().contains("division by zero"),
+            "error message should mention division by zero"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_modulo_by_zero() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10u32);
+        segment.just(0u32);
+        lookup.lookup("%", &mut segment, 2)?;
+        let result = segment.call0::<u32>();
+        assert!(result.is_err());
+        assert!(
+            result.unwrap_err().to_string().contains("division by zero"),
+            "error message should mention division by zero"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_multiplication_f64() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(3.5f64);
+        segment.just(2.0f64);
+        lookup.lookup("*", &mut segment, 2)?;
+        assert_eq!(segment.call0::<f64>()?, 7.0);
+        Ok(())
+    }
+
+    #[test]
+    fn test_comparison_less_than() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10u32);
+        segment.just(20u32);
+        lookup.lookup("<", &mut segment, 2)?;
+        assert_eq!(segment.call0::<bool>()?, true);
+        Ok(())
+    }
+
+    #[test]
+    fn test_logical_and() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(true);
+        segment.just(false);
+        lookup.lookup("&&", &mut segment, 2)?;
+        assert_eq!(segment.call0::<bool>()?, false);
+        Ok(())
+    }
+
+    #[test]
+    fn test_bitwise_and() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(0b1010u32);
+        segment.just(0b1100u32);
+        lookup.lookup("&", &mut segment, 2)?;
+        assert_eq!(segment.call0::<u32>()?, 0b1000);
+        Ok(())
+    }
+
+    #[test]
+    fn test_unary_negation() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(42i32);
+        lookup.lookup("-", &mut segment, 1)?;
+        assert_eq!(segment.call0::<i32>()?, -42);
+        Ok(())
+    }
+
+    #[test]
+    fn test_logical_not() -> Result<()> {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(true);
+        lookup.lookup("!", &mut segment, 1)?;
+        assert_eq!(segment.call0::<bool>()?, false);
+        Ok(())
+    }
+
+    #[test]
+    fn test_unregistered_operation() {
+        let lookup = OpLookup::new();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10u32);
+        segment.just(20u32);
+        let result = lookup.lookup("unknown_op", &mut segment, 2);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_custom_scope() -> Result<()> {
+        let mut lookup = OpLookup::new();
+
+        // Add a custom scope that handles "double"
+        lookup.push_scope(|name, segment, num_operands| {
+            let matches = {
+                let top = segment.peek_stack_infos(num_operands);
+                name == "double" && top.len() == 1 && top[0].type_id == TypeId::of::<u32>()
+            };
+            if matches {
+                segment.op1(|a: u32| a * 2)?;
+                Ok(true)
+            } else {
+                Ok(false)
+            }
+        });
+
+        let mut segment = DynSegment::new::<()>();
+        segment.just(21u32);
+        lookup.lookup("double", &mut segment, 1)?;
+        assert_eq!(segment.call0::<u32>()?, 42);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_scope_override() -> Result<()> {
+        let mut lookup = OpLookup::new();
+
+        // Override addition to always return 100
+        lookup.push_scope(|name, segment, num_operands| {
+            let matches = {
+                let top = segment.peek_stack_infos(num_operands);
+                name == "+" && top.len() == 2 && top[0].type_id == TypeId::of::<u32>()
+            };
+            if matches {
+                segment.op2(|_a: u32, _b: u32| 100u32)?;
+                Ok(true)
+            } else {
+                Ok(false)
+            }
+        });
+
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10u32);
+        segment.just(20u32);
+        lookup.lookup("+", &mut segment, 2)?;
+        assert_eq!(segment.call0::<u32>()?, 100);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_scope_pop() -> Result<()> {
+        let mut lookup = OpLookup::new();
+
+        lookup.push_scope(|name, segment, num_operands| {
+            let matches = {
+                let top = segment.peek_stack_infos(num_operands);
+                name == "+" && top.len() == 2 && top[0].type_id == TypeId::of::<u32>()
+            };
+            if matches {
+                segment.op2(|_a: u32, _b: u32| 100u32)?;
+                Ok(true)
+            } else {
+                Ok(false)
+            }
+        });
+
+        // Test with override
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10u32);
+        segment.just(20u32);
+        lookup.lookup("+", &mut segment, 2)?;
+        assert_eq!(segment.call0::<u32>()?, 100);
+
+        // Pop scope and test normal behavior
+        lookup.pop_scope();
+        let mut segment = DynSegment::new::<()>();
+        segment.just(10u32);
+        segment.just(20u32);
+        lookup.lookup("+", &mut segment, 2)?;
+        assert_eq!(segment.call0::<u32>()?, 30);
+
+        Ok(())
+    }
+}
diff --git a/cel-runtime/src/raw_segment.rs b/cel-runtime/src/raw_segment.rs
index 474adb5..315c9f9 100644
--- a/cel-runtime/src/raw_segment.rs
+++ b/cel-runtime/src/raw_segment.rs
@@ -223,6 +223,41 @@ impl RawSegment {
         self.base_alignment = max(self.base_alignment, align_of::<F>());
     }
 
+    fn push_op2r_<F, T, U, R, const PADDING0: bool, const PADDING1: bool>(&mut self)
+    where
+        F: Fn(&mut RawStack, T, U) -> Result<R> + 'static,
+        T: 'static,
+        U: 'static,
+        R: 'static,
+    {
+        self.ops.push(|storage, p, stack| {
+            let (f, r) = unsafe { storage.next::<F>(p) };
+            let y: U = unsafe { stack.pop(PADDING1) };
+            let x: T = unsafe { stack.pop(PADDING0) };
+            let result = f(stack, x, y)?;
+            stack.push(result);
+            Ok(r)
+        });
+    }
+
+    /// Push a fallible binary operation that can manipulate the stack.
+    pub fn raw2<F, T, U, R>(&mut self, op: F, padding0: bool, padding1: bool)
+    where
+        F: Fn(&mut RawStack, T, U) -> Result<R> + 'static,
+        T: 'static,
+        U: 'static,
+        R: 'static,
+    {
+        self.push_storage(op);
+        match (padding0, padding1) {
+            (false, false) => self.push_op2r_::<F, T, U, R, false, false>(),
+            (false, true) => self.push_op2r_::<F, T, U, R, false, true>(),
+            (true, false) => self.push_op2r_::<F, T, U, R, true, false>(),
+            (true, true) => self.push_op2r_::<F, T, U, R, true, true>(),
+        }
+        self.base_alignment = max(self.base_alignment, align_of::<F>());
+    }
+
     /// Pushes a ternary operation that takes three arguments of types T, U, and V and returns a
     /// value of type R.
     #[expect(clippy::many_single_char_names, reason = "patterned code")]
diff --git a/cel-runtime/src/segment.rs b/cel-runtime/src/segment.rs
index 91d90e9..20ffe9d 100644
--- a/cel-runtime/src/segment.rs
+++ b/cel-runtime/src/segment.rs
@@ -89,7 +89,7 @@ where
         ensure!(
             Stack::LENGTH == value.stack_ids.len()
                 && TypeIdIterator::<Stack>::new()
-                    .eq(value.stack_ids.iter().map(|info| info.stack_id)),
+                    .eq(value.stack_ids.iter().map(|info| info.type_id)),
             "stack type ids do not match"
         );
         Ok(Segment {
diff --git a/notes.txt b/notes.txt
index 9643386..8dfaf51 100644
--- a/notes.txt
+++ b/notes.txt
@@ -11,6 +11,7 @@ Building the docs
 cargo doc --lib --no-deps --open --workspace
 cargo test --doc --workspace
 cargo clippy --workspace
+cargo clippy --fix --workspace
 ```
 
 Figure out if the call operators should be able to reuse the same stack.
diff --git a/src/lib.rs b/src/lib.rs
index 5504471..4ef1059 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,10 +3,9 @@
 //! cel-rs provides a stack-based runtime for developing domain specific languages, including
 //! concatenative languages to describe concurrent processes.
 //!
-//! This crate exposes three main components:
+//! This crate exposes two main components:
 //!
-//! - **cel-runtime**: The core stack-based runtime for developing domain specific languages
-//! - **cel-parser**: A recursive descent parser for CEL expressions
+//! - **cel-runtime**: The core stack-based runtime and CEL parser
 //! - **cel-rs-macros**: Procedural macros for CEL expressions
 //!
 //! # Examples
@@ -30,13 +29,15 @@
 //! ## Using the Parser
 //!
 //! ```rust
-//! use cel_rs::cel_parser::CELParser;
+//! use cel_rs::cel_runtime::{CELParser, OpLookup};
 //! use proc_macro2::TokenStream;
 //! use std::str::FromStr;
 //!
-//! let input = TokenStream::from_str("10 + 20").unwrap();
-//! let mut parser = CELParser::new(input.into_iter());
-//! assert!(parser.is_expression());
+//! let input = TokenStream::from_str("10").unwrap();
+//! let mut parser = CELParser::new(OpLookup::new());
+//! parser.set_tokens(input.into_iter());
+//! let result = parser.is_expression();
+//! assert!(result.is_ok());
 //! ```
 //!
 //! ## Using the Macros
 //!
 //! ```rust
@@ -49,7 +50,6 @@
 //! };
 //! ```
 
-pub use cel_parser;
 pub use cel_rs_macros;
 pub use cel_runtime;
@@ -59,12 +59,41 @@
 pub mod runtime {
     pub use cel_runtime::*;
 }
 
-/// Re-exports from the cel-parser crate for convenient access.
+/// Re-exports for the CEL parser (part of cel-runtime).
 pub mod parser {
-    pub use cel_parser::*;
+    pub use cel_runtime::parser::{op_table::OpLookup, CELParser};
 }
 
 /// Re-exports from the cel-rs-macros crate for convenient access.
 pub mod macros {
     pub use cel_rs_macros::*;
 }
+
+#[cfg(test)]
+mod tests {
+
+    struct Experiment {
+        a: u32,
+    }
+
+    struct Segment<T, F: FnMut(&mut T) -> String> {
+        context: T,
+        f: F,
+    }
+
+    impl<T, F: FnMut(&mut T) -> String> Segment<T, F> {
+        fn new(context: T, f: F) -> Self {
+            Self { context, f }
+        }
+
+        fn call(&mut self) -> String {
+            (self.f)(&mut self.context)
+        }
+    }
+
+    #[test]
+    fn experiment() {
+        let experiment = Experiment { a: 1 };
+        let mut segment = Segment::new(experiment, |e| e.a.to_string());
+        assert_eq!(segment.call(), "1");
+    }
+}
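The `OpLookup` added in this diff is a scope-stack resolver: scopes are consulted newest-first, and the builtin table is the fallback. A minimal standalone sketch of that pattern, independent of the crate's `DynSegment` machinery (`Lookup`, `Scope`, and the `Vec<i64>` stack here are illustrative stand-ins, not the real API):

```rust
use std::collections::HashMap;

// A scope is a boxed handler that returns true if it handled the operation.
type Scope = Box<dyn Fn(&str, &mut Vec<i64>) -> bool>;

struct Lookup {
    scopes: Vec<Scope>,
    builtins: HashMap<&'static str, fn(i64, i64) -> i64>,
}

impl Lookup {
    fn new() -> Self {
        let mut builtins: HashMap<&'static str, fn(i64, i64) -> i64> = HashMap::new();
        builtins.insert("+", |a, b| a + b);
        builtins.insert("*", |a, b| a * b);
        Lookup { scopes: Vec::new(), builtins }
    }

    fn push_scope(&mut self, s: impl Fn(&str, &mut Vec<i64>) -> bool + 'static) {
        self.scopes.push(Box::new(s));
    }

    fn pop_scope(&mut self) -> Option<Scope> {
        self.scopes.pop()
    }

    fn apply(&self, name: &str, stack: &mut Vec<i64>) -> Result<(), String> {
        // LIFO: the most recently pushed scope wins; builtins are the fallback.
        for scope in self.scopes.iter().rev() {
            if scope(name, stack) {
                return Ok(());
            }
        }
        if let Some(op) = self.builtins.get(name) {
            let b = stack.pop().ok_or("stack underflow")?;
            let a = stack.pop().ok_or("stack underflow")?;
            stack.push(op(a, b));
            return Ok(());
        }
        Err(format!("operation '{name}' not found"))
    }
}

fn main() {
    let mut lookup = Lookup::new();
    let mut stack = vec![10, 20];
    lookup.apply("+", &mut stack).unwrap();
    assert_eq!(stack, vec![30]);

    // An override pushed as a scope shadows the builtin until popped.
    lookup.push_scope(|name, stack| {
        if name == "+" {
            stack.pop();
            stack.pop();
            stack.push(100);
            true
        } else {
            false
        }
    });
    stack = vec![10, 20];
    lookup.apply("+", &mut stack).unwrap();
    assert_eq!(stack, vec![100]);

    lookup.pop_scope();
    stack = vec![10, 20];
    lookup.apply("+", &mut stack).unwrap();
    assert_eq!(stack, vec![30]);
    println!("ok");
}
```

This mirrors the behavior exercised by `test_scope_override` and `test_scope_pop` above: pushing a scope shadows a builtin, popping it restores the original resolution.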
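The shift signatures in this diff use `checked_shl`/`checked_shr` with the `b as u32` cast for the signed types. A quick standalone demonstration of the std behavior that pattern relies on (plain std, nothing from this crate):

```rust
fn main() {
    // checked_shl returns None when the shift count is >= the bit width,
    // where plain `<<` would panic in debug builds.
    assert_eq!(1i8.checked_shl(3), Some(8));
    assert_eq!(1i8.checked_shl(8), None);

    // A negative shift amount cast with `as u32` becomes a huge count,
    // so the `b as u32` + checked_shl pattern also rejects negative shifts.
    let b: i8 = -1;
    assert_eq!(1i8.checked_shl(b as u32), None);

    // Note: checked_shl only guards the shift *amount*. Bits can still be
    // shifted out of the value; 64i8 << 1 wraps to -128 and is Some here.
    assert_eq!(64i8.checked_shl(1), Some(-128));

    println!("ok");
}
```

So the `anyhow!("shift overflow")` error in the diff fires for out-of-range (or negative) shift counts, not for value bits lost off the top.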