diff --git a/soteria-rust/lib/core.ml b/soteria-rust/lib/core.ml index aeee95d80..a94664432 100644 --- a/soteria-rust/lib/core.ml +++ b/soteria-rust/lib/core.ml @@ -12,13 +12,10 @@ module M (Rust_state_m : Rust_state_m.S) = struct let cmp ~signed l r = let ( < ) = if signed then ( <$@ ) else ( <@ ) in - if%sat l < r then ok U8.(-1s) - else if%sat l ==@ r then ok U8.(0s) else ok U8.(1s) - - let cmp_of_int v = - let zero = BV.zero (size_of_int v) in - if%sat v <$@ zero then ok U8.(-1s) - else if%sat v ==@ zero then ok U8.(0s) else ok U8.(1s) + let discr = + Typed.ite (l < r) U8.(-1s) (Typed.ite (l ==@ r) U8.(0s) U8.(1s)) + in + Enum (discr, []) let rec equality_check (v1 : [< T.sint | T.sptr ] Typed.t) (v2 : [< T.sint | T.sptr ] Typed.t) = @@ -168,18 +165,6 @@ module M (Rust_state_m : Rust_state_m.S) = struct ok (BV.of_bool (bop ml mr)) else ok (BV.of_bool v) else error `UBPointerComparison - | Cmp, Ptr (l, _), Ptr (r, _) -> - let* () = State.assert_ (Sptr.is_same_loc l r) `UBPointerComparison in - let* v = Sptr.distance l r in - let* cmp = cmp_of_int v in - ok cmp - | Cmp, Ptr (p, _), Int v | Cmp, Int v, Ptr (p, _) -> - if%sat v ==@ BV.usizei (Layout.size_of_uint_ty Usize) then - if%sat Sptr.is_at_null_loc p then ok U8.(0s) - else if l = Int v then ok U8.(1s) - else ok U8.(-1s) - else - Fmt.kstr not_impl "Don't know how to eval %a cmp %a" Sptr.pp p ppa v | op, l, r -> Fmt.kstr not_impl "Unexpected operation or value in eval_ptr_binop: %a, %a, %a" @@ -195,20 +180,14 @@ module M (Rust_state_m : Rust_state_m.S) = struct L.debug (fun m -> m "Transmuting %a: %a -> %a" pp_rust_val v Charon_util.pp_ty from_ty Charon_util.pp_ty to_ty); - let* state = get_state () in - let^ res = - let@ () = run ~env:() ~state in - let* { size; align; _ } = Layout.layout_of from_ty in - let* { align = align_2; _ } = Layout.layout_of to_ty in - let align = BV.max ~signed:false align align_2 in - let* ptr = State.alloc_untyped ~zeroed:false ~size ~align () in - let* () = State.store ptr 
from_ty v in - State.load ptr to_ty - in - match res with - | Ok (v, _) -> ok v - | Error e -> error_raw e - | Missing m -> miss m + let* { size; align; _ } = Layout.layout_of from_ty in + let* { align = align_2; _ } = Layout.layout_of to_ty in + let align = BV.max ~signed:false align align_2 in + let* ptr = State.alloc_untyped ~zeroed:false ~size ~align () in + let* () = State.store ptr from_ty v in + let* v = State.load ptr to_ty in + let+ () = State.free ptr in + v let zero_valid ~ty = let^+ res = diff --git a/soteria-rust/lib/crate.ml b/soteria-rust/lib/crate.ml index 823476705..bc7ebf951 100644 --- a/soteria-rust/lib/crate.ml +++ b/soteria-rust/lib/crate.ml @@ -68,6 +68,9 @@ let is_enum adt_id = let is_struct adt_id = match (get_adt adt_id).kind with Struct _ -> true | _ -> false +let is_union adt_id = + match (get_adt adt_id).kind with Union _ -> true | _ -> false + let as_enum adt_id = match (get_adt adt_id).kind with | Enum variants -> variants diff --git a/soteria-rust/lib/encoder.ml b/soteria-rust/lib/encoder.ml index 2b548ea53..e6488185e 100644 --- a/soteria-rust/lib/encoder.ml +++ b/soteria-rust/lib/encoder.ml @@ -98,6 +98,9 @@ module Make (Sptr : Sptr.S) = struct DecayMapMonad.Result.map (assert_or_error cond err) (fun () -> ((), state)) + let fold_iter x ~init ~f = + Monad.foldM ~bind ~return:ok ~fold:Foldable.Iter.fold x ~init ~f + module Syntax = struct let ( let* ) x f = bind x f let ( let+ ) x f = map x f @@ -112,145 +115,113 @@ module Make (Sptr : Sptr.S) = struct end end - type cval_info = rust_val * T.sint Typed.t - [@@deriving show { with_path = false }] - - (** Converts a Rust value of the given type into a list of sub values, along - with their size and offset, and whether they are interiorly mutable. *) - let rec rust_to_cvals ?offset (value : rust_val) (ty : Types.ty) : - (cval_info list, 'e, 'f) Rustsymex.Result.t = + (** Iterator over the fields and offsets of a type; for primitive types, + returns a singleton iterator for that value. 
*) + let iter_fields ?variant ?(meta = Thin) layout (ty : Types.ty) = + let aux ?variant fields = + Iter.mapi (fun i ty -> (ty, Fields_shape.offset_of i fields)) + @@ + match ty with + | TAdt { id = TTuple; generics = { types; _ } } -> Iter.of_list types + | TAdt + { + id = TBuiltin TArray; + generics = { types = [ ty ]; const_generics = [ len ]; _ }; + } -> + Iter.repeatz (z_of_const_generic len) ty + | TAdt { id = TBuiltin ((TStr | TSlice) as kind); generics } -> ( + let sub_ty = + match kind with + | TSlice -> List.hd generics.types + | TStr -> TLiteral (TUInt U8) + | _ -> failwith "unreachable" + in + match meta with + | Len len when Option.is_some (BV.to_z len) -> + (* TODO: strings and slices of symbolic length *) + Iter.repeatz (Option.get (BV.to_z len)) sub_ty + | Thin | Len _ | VTable _ -> + failwith "iter_fields: invalid length for slice/str") + | TAdt { id = TAdtId t_id; _ } -> ( + let type_decl = Crate.get_adt t_id in + match (type_decl.kind, variant) with + | Struct fields, _ -> + let field_tys = field_tys fields in + Iter.of_list field_tys + | Enum variants, Some variant -> + let variant = Types.VariantId.nth variants variant in + let field_tys = field_tys variant.fields in + Iter.of_list field_tys + | _ -> failwith "invalid iter_fields type_decl") + | TRef (_, pointee, _) | TRawPtr (pointee, _) -> ( + match Layout.dst_kind pointee with + | NoneKind -> failwith "invalid iter_fields: no metadata" + | LenKind -> Iter.of_list [ unit_ptr; TLiteral (TInt Isize) ] + | VTableKind -> Iter.of_list [ unit_ptr; unit_ptr ]) + | _ -> Fmt.failwith "invalid iter_fields: %a" pp_ty ty + in + match layout.fields with + | Primitive -> Iter.singleton (ty, Usize.(0s)) + | Array _ -> aux ?variant layout.fields + | Arbitrary (variant, _) -> aux ~variant layout.fields + | Enum (_, variant_layouts) -> + let variant = Option.get ~msg:"variant required for enum" variant in + let fields = variant_layouts.(Types.VariantId.to_int variant) in + aux ~variant fields + + (** [encode 
?offset v ty] Converts a [Rust_val.t] of type [ty] into an + iterator over its sub values, along with their offset. Offsets all blocks + by [offset] if specified *) + let rec encode ~offset (value : rust_val) (ty : Types.ty) : + ((rust_val * T.sint Typed.t) Iter.t, 'e, 'f) Rustsymex.Result.t = let open Rustsymex in let open Syntax in let open Result in - let illegal_pair () = - L.error (fun m -> - m "Wrong pair of rust_value and Charon.ty:@.- Val: %a@.- Ty: %a" - pp_rust_val value Types.pp_ty ty); - failwith "Wrong pair of rust_value and Charon.ty" + let chain iter = + (match value with + | Tuple vals | Enum (_, vals) -> vals + | Ptr (base, VTable vt) -> [ Ptr (base, Thin); Ptr (vt, Thin) ] + | Ptr (base, Len len) -> [ Ptr (base, Thin); Int len ] + | Ptr (_, Thin) | Int _ | Float _ | ConstFn _ -> + failwith "Cannot split primitive" + | Union _ -> failwith "Cannot encode union directly") + |> Iter.combine_list iter + |> Result.fold_iter ~init:(0, Iter.empty) + ~f:(fun (i, acc) ((ty, ofs), v) -> + let offset = offset +!!@ ofs in + let++ ys = encode ~offset v ty in + (i + 1, Iter.append acc ys)) + |> (Fun.flip Result.map) snd in - let offset = Option.value ~default:Usize.(0s) offset in - let chain_cvals layout vals types = - let sub_vals = List.combinei vals types in - fold_list sub_vals ~init:[] ~f:(fun acc (i, value, ty) -> - let field_offset = Layout.Fields_shape.offset_of i layout.fields in - let offset = field_offset +!!@ offset in - let++ blocks = rust_to_cvals ~offset value ty in - acc @ blocks) - in - match (value, ty) with - (* Trait types: we resolve them early *) - | _, TTraitType (tref, name) -> - let** ty = Layout.resolve_trait_ty tref name in - rust_to_cvals ~offset value ty - (* Literals *) - | Int _, TLiteral _ -> ok [ (value, offset) ] - | Float _, TLiteral _ -> ok [ (value, offset) ] - | Ptr _, TLiteral (TInt Isize | TUInt Usize) -> ok [ (value, offset) ] - | _, TLiteral _ -> illegal_pair () - (* References / Pointers *) - | ( Ptr (ptr, meta), - TAdt 
{ id = TBuiltin TBox; generics = { types = [ sub_ty ]; _ } } ) - | Ptr (ptr, meta), TRef (_, sub_ty, _) - | Ptr (ptr, meta), TRawPtr (sub_ty, _) -> ( - let size = BV.usizei (Layout.size_of_int_ty Isize) in - match (meta, is_dst sub_ty) with - | _, false -> ok [ (Ptr (ptr, Thin), offset) ] - | Thin, true -> failwith "Expected a fat pointer" - | Len len, true -> - ok [ (Ptr (ptr, Thin), offset); (Int len, offset +!!@ size) ] - | VTable vt, true -> - ok [ (Ptr (ptr, Thin), offset); (Ptr (vt, Thin), offset +!!@ size) ] - ) - (* Function pointer *) - | Ptr (_, Thin), TFnPtr _ -> ok [ (value, offset) ] - (* References / Pointers obtained from casting *) - | Int _, TAdt { id = TBuiltin TBox; _ } - | Int _, TRef _ - | Int _, TRawPtr _ - | Int _, TFnPtr _ -> - ok [ (value, offset) ] - | _, TAdt { id = TBuiltin TBox; _ } | _, TRawPtr _ | _, TRef _ -> - illegal_pair () - (* Tuples *) - | Tuple vs, TAdt { id = TTuple; generics = { types; _ } } -> - let** layout = layout_of ty in - chain_cvals layout vs types - | _, TAdt { id = TTuple; _ } -> illegal_pair () - (* Structs *) - | Tuple vals, TAdt { id = TAdtId t_id; _ } -> - let fields = field_tys @@ Crate.as_struct t_id in - let** layout = layout_of ty in - chain_cvals layout vals fields - (* Enums *) - | Enum (disc, vals), (TAdt { id = TAdtId t_id; _ } as ty) -> ( - let variants = Crate.as_enum t_id in - let** layout = Layout.layout_of ty in - match layout with - | { fields = Arbitrary (variant, _); _ } -> - let variant = Types.VariantId.nth variants variant in - let var_fields = field_tys variant.fields in - chain_cvals layout vals var_fields - | { fields = Enum (tag_layout, var_fields); _ } -> - let disc = - Option.get ~msg:"Discriminant not concrete" (BV.to_z disc) - in - let variant_id, variant = - Option.get ~msg:"No matching variant?" 
- @@ List.find_mapi - (fun i v -> - let v_disc = z_of_literal Types.(v.discriminant) in - if Z.equal disc v_disc then Some (i, v) else None) - variants - in - let var_fields = var_fields.(variant_id) in - let discriminant = - match tag_layout.tags.(variant_id) with - | None -> [] - | Some tag -> - let offset = tag_layout.offset +!!@ offset in - [ (Int tag, offset) ] - in - let var_layout = { layout with fields = var_fields } in - let++ content = - chain_cvals var_layout vals (field_tys variant.fields) - in - discriminant @ content - | _ -> Fmt.failwith "Unexpected layout for enum") - | Int value, TAdt { id = TAdtId t_id; _ } when Crate.is_enum t_id -> - let++ layout = Layout.layout_of ty in - let tag_ty = - match layout.fields with - | Enum (tag, _) -> tag.ty - | _ -> failwith "Expected enum layout" - in - let value = Typed.cast_lit tag_ty value in - [ (Int value, offset) ] - | Enum _, _ -> illegal_pair () - (* Arrays *) - | ( Tuple vals, - TAdt - { - id = TBuiltin TArray; - generics = { types = [ sub_ty ]; const_generics = [ len ]; _ }; - } ) -> - let** layout = layout_of ty in - let len = int_of_const_generic len in - if List.length vals <> len then failwith "Array length mismatch"; - chain_cvals layout vals (List.init len (fun _ -> sub_ty)) - | _, TAdt { id = TBuiltin TArray; _ } -> illegal_pair () - (* Unions *) - | Union blocks, TAdt { id = TAdtId _; _ } -> - ok (List.map (fun (v, o) -> (v, offset +!!@ o)) blocks) - | Union _, _ -> illegal_pair () - (* Static Functions (ZSTs) *) - | ConstFn _, TFnDef _ -> ok [] - | ConstFn _, _ | _, TFnDef _ -> illegal_pair () - (* Should have been handled for arrays, tuples and structs *) - | Tuple _, _ -> illegal_pair () - (* Rest *) - | _ -> - Fmt.kstr not_impl "Unhandled value/ty: %a / %a" pp_rust_val value pp_ty - ty + let** layout = Layout.layout_of ty in + if%sat layout.size ==@ Usize.(0s) then ok Iter.empty + else + match (layout.fields, value) with + | _, Union blocks -> + ok (Iter.of_list blocks |> Iter.map (fun (v, 
o) -> (v, offset +!!@ o))) + | Primitive, _ -> ok (Iter.singleton (value, offset)) + | Array _, _ | Arbitrary (_, _), _ -> chain (iter_fields layout ty) + | Enum (tag_layout, _), Enum (disc, _) -> ( + let adt_id, _ = TypesUtils.ty_as_custom_adt ty in + let variants = Crate.as_enum adt_id in + let variants = List.mapi (fun i v -> (i, v)) variants in + let* variant = + match_on variants ~constr:(fun (_, v) -> + BV.of_literal v.discriminant ==@ disc) + in + let* i, _ = + of_opt_not_impl "no matching variant for enum discriminant" variant + in + let variant = Types.VariantId.of_int i in + let++ fields = chain (iter_fields ~variant layout ty) in + match tag_layout.tags.(i) with + | None -> fields + | Some tag -> + let offset = tag_layout.offset +!!@ offset in + Iter.cons (Int tag, offset) fields) + | Enum _, _ -> + Fmt.kstr not_impl "encode: expected enum value for enum type %a" pp_ty + ty (** Parses the current variant of the enum at the given offset. This handles cases such as niches, where the discriminant isn't directly encoded as a @@ -262,18 +233,12 @@ module Make (Sptr : Sptr.S) = struct let* layout = layout_of ty in (* if it's a ZST, we assume it's the first variant; I don't think this is always true, e.g. enum { A(!), B }, but it's ok for now. *) - match layout with - | { fields = Arbitrary (vid, _); _ } -> ok vid - | { fields = Enum (tag_layout, _); _ } -> ( + match layout.fields with + | Arbitrary (vid, _) -> ok vid + | Enum (tag_layout, _) -> ( let offset = offset +!!@ tag_layout.offset in let* tag = query (TLiteral tag_layout.ty, offset) in - (* here we need to check and decay if it's a pointer, for niche encoding! 
*) - let* tag = - match tag with - | Int tag -> ok (Typed.cast_lit tag_layout.ty tag) - | Ptr (p, Thin) -> lift @@ Sptr.decay p - | _ -> Fmt.failwith "Unexpected tag: %a" pp_rust_val tag - in + let tag = as_base tag_layout.ty tag in let tags = Array.to_seqi tag_layout.tags |> List.of_seq in let* res = lift @@ -292,204 +257,105 @@ module Make (Sptr : Sptr.S) = struct in error (`UBTransmute msg) | Niche untagged, None -> ok untagged) - | _ -> failwith "Unexpected layout for enum" - - type ('e, 'fix, 'state) parser = (rust_val, 'state, 'e, 'fix) ParserMonad.t + | Array _ | Primitive -> failwith "Unexpected layout for enum" - (** Converts a Rust type into a list of types to read, along with their - offset; once these are read, symbolically decides whether we must keep - reading. [offset] is the initial offset to read from, [meta] is the - optional metadata, that originates from a fat pointer. *) - let rust_of_cvals ?(meta = Thin) ~is_valid_ptr ~offset : + (** [decode ~meta ~offset ty] Parses a rust value of type [ty] at the given + offset, using the provided metadata for DSTs, and returns the associated + [Rust_val]. This does not perform any validity checking, aside from + erroring if the type is uninhabited. *) + let decode ~meta ~offset : Types.ty -> (rust_val, 'state, 'e, 'fix) ParserMonad.t = let open ParserMonad in let open ParserMonad.Syntax in let module T = Typed.T in (* Base case, parses all types. 
*) - let rec aux offset ty : ('e, 'fix, 'state) parser = - match (ty : Types.ty) with - | TLiteral _ as ty -> ( - let* q_res = query (ty, offset) in - match q_res with - | (Int _ | Float _) as v -> ok v - | Ptr (ptr, Thin) -> - let+ ptr_v = lift @@ Sptr.decay ptr in - Int ptr_v - | _ -> - Fmt.kstr not_impl "Expected a base or a thin pointer, got %a" - pp_rust_val q_res) - | ( TAdt { id = TBuiltin TBox; generics = { types = [ sub_ty ]; _ } } - | TRef (_, sub_ty, _) - | TRawPtr (sub_ty, _) ) as ty -> - let must_be_valid = - match ty with - | TRef _ | TAdt { id = TBuiltin TBox; _ } -> true - | TRawPtr _ -> false - | _ -> failwith "Impossible" - in - let ptr_size = BV.usizei @@ Layout.size_of_int_ty Isize in - let meta_kind = dst_kind sub_ty in - (* Small hack; we don't want to read the metadata ! *) - let ty = - match (meta_kind, ty) with - | NoneKind, _ -> ty - | _, TRawPtr _ -> unit_ptr - | _, _ -> unit_ref - in - let* ptr = query (ty, offset) in - let* meta : Sptr.t Rust_val.meta = - match meta_kind with - | NoneKind -> ok Thin - | LenKind -> ( - let isize : Types.ty = TLiteral (TInt Isize) in - let* meta = query (isize, offset +!!@ ptr_size) in - match meta with - | Int meta -> ok (Len meta) - | _ -> not_impl "Unexpected metadata value") - | VTableKind -> ( - let* meta = query (unit_ptr, offset +!!@ ptr_size) in - match meta with - | Ptr (meta_v, Thin) -> ok (VTable meta_v) - | _ -> not_impl "Unexpected metadata value") - in - let* () = - match (must_be_valid, ptr) with - | false, _ -> ok () - | true, Ptr ptr -> ( - let* valid = - lift @@ DecayMapMonad.lift @@ is_valid_ptr ptr sub_ty - in - match (valid, meta) with - | Some err, _ -> error err - | None, Len len when must_be_valid -> - assert_or_error - (Usize.(0s) <=$@ len) - (`UBTransmute "Negative slice length") - | None, _ -> ok ()) - | true, _ -> error `UBDanglingPointer - in - let+ ptr = - match ptr with - | Ptr (ptr_v, Thin) -> ok ptr_v - | Int ptr_v -> - let ptr_v = Typed.cast_i Usize ptr_v in - ok 
(Sptr.null_ptr_of ptr_v) - | v -> - Fmt.kstr not_impl "Unexpected pointer value: %a" pp_rust_val v - in - Ptr (ptr, meta) - | TFnPtr _ as ty -> ( - let* boxed = query (ty, offset) in - match boxed with - | Ptr (p, _) as ptr -> - let+ () = - assert_or_error - (Typed.not (Sptr.sem_eq (Sptr.null_ptr ()) p)) - `UBDanglingPointer - in - ptr - | Int _ -> error `UBDanglingPointer - | _ -> not_impl "Expected a pointer or base") - | TAdt { id = TTuple; generics = { types; _ } } as ty -> - let* layout = layout_of ty in - let types = List.to_seq types in - aux_fields ~f:(fun fs -> Tuple fs) ~layout offset types - | TAdt { id = TAdtId t_id; _ } as ty -> ( - let type_decl = Crate.get_adt t_id in - match type_decl.kind with - | Struct fields -> - let* layout = layout_of ty in - fields - |> field_tys - |> List.to_seq - |> aux_fields ~f:(fun fs -> Tuple fs) ~layout offset - | Enum [] -> error `RefToUninhabited - | Enum variants -> aux_enum offset ty variants - | Union _ -> - let* layout = layout_of ty in - if%sat layout.size ==@ Usize.(0s) then ok (Union []) - else - (* FIXME: this isn't exactly correct; union actually doesn't copy the padding - bytes (i.e. the intersection of the padding bytes of all fields). It is - quite painful to actually calculate these padding bytes so we just copy - the whole thing for now. 
- See https://github.com/rust-lang/unsafe-code-guidelines/issues/518 - And a proper implementation is here: - https://github.com/minirust/minirust/blob/master/tooling/minimize/src/chunks.rs *) - let+ blocks = get_all (Typed.cast layout.size, offset) in - Union blocks - | _ -> - Fmt.kstr failwith "Unhandled ADT kind in rust_of_cvals: %a" - Types.pp_type_decl_kind type_decl.kind) - | TAdt { id = TBuiltin TArray; generics = { types; const_generics; _ } } - as ty -> - let sub_ty = List.hd types in - let len = z_of_const_generic @@ List.hd const_generics in - let* layout = layout_of ty in - let fields = Seq.init_z len (fun _ -> sub_ty) in - aux_fields ~f:(fun fs -> Tuple fs) ~layout offset fields - | TAdt { id = TBuiltin (TStr as ty); generics } - | TAdt { id = TBuiltin (TSlice as ty); generics } -> - (* We can only read a slice if we have the metadata of its length, in which case - we interpret it as an array of that length. *) - let* len = - match meta with - | Thin -> failwith "Tried reading slice without metadata" - | Len l -> ok l - | VTable ptr -> lift @@ Sptr.decay ptr - in - let* len = - of_opt_not_impl - (Fmt.str "Slice length not concrete: %a" Typed.ppa len) - (BV.to_z len) - in - let sub_ty = - if ty = TSlice then List.hd generics.types else TLiteral (TUInt U8) - in - (* FIXME: This is a bit hacky, and not performant -- instead we should try to - group the reads together, at least for primitive types. *) - let arr_ty = mk_array_ty sub_ty len in - let* layout = layout_of arr_ty in - let fields = Seq.init_z len (fun _ -> sub_ty) in - aux_fields ~f:(fun fs -> Tuple fs) ~layout offset fields - | TNever -> error `RefToUninhabited - | TTraitType (tref, name) -> - let* ty = lift_rsymex @@ Layout.resolve_trait_ty tref name in - aux offset ty - | TFnDef fnptr -> ok (ConstFn fnptr.binder_value) - | TDynTrait _ -> not_impl "Tried reading a trait object?" 
- | TAdt { id = TBuiltin TBox; _ } -> failwith "Invalid box" - | (TVar _ | TError _ | TPtrMetadata _) as ty -> - Fmt.kstr not_impl "Unhandled Charon.ty: %a" Types.pp_ty ty - (* Parses a sequence of fields (for structs, tuples, arrays) *) - and aux_fields ~f ~(layout : Layout.t) offset (fields : Types.ty Seq.t) : - ('e, 'fix, 'state) parser = - let base_offset = offset +!!@ (offset %@ layout.align) in - let rec mk_callback idx to_parse parsed : ('e, 'fix, 'state) parser = - match to_parse () with - | Seq.Nil -> ok (f (List.rev parsed)) - | Seq.Cons (ty, rest) -> - let field_off = Layout.Fields_shape.offset_of idx layout.fields in - let offset = base_offset +!!@ field_off in - bind (aux offset ty) (fun v -> - mk_callback (succ idx) rest (v :: parsed)) - in - mk_callback 0 fields [] - (* Parses what enum variant we're handling *) - and aux_enum offset ty variants : ('e, 'fix, 'state) parser = - let* v_id = variant_of_enum ~offset ty in + let rec aux offset ty : (rust_val, 'state, 'e, 'fix) ParserMonad.t = let* layout = layout_of ty in - let fields = Layout.Fields_shape.shape_for_variant v_id layout.fields in - let variant = Types.VariantId.nth variants v_id in - let layout = { layout with fields } in - let discr = BV.of_literal variant.discriminant in - variant.fields - |> field_tys - |> List.to_seq - |> aux_fields ~f:(fun fs -> Enum (discr, fs)) ~layout offset + match (layout.fields, ty) with + | _, TDynTrait _ -> not_impl "Tried reading a trait object?" + | _, TAdt { id = TAdtId id; _ } when Crate.is_union id -> + if%sat layout.size ==@ Usize.(0s) then ok (Union []) + else + (* FIXME: this isn't exactly correct; union actually doesn't copy the padding + bytes (i.e. the intersection of the padding bytes of all fields). It is + quite painful to actually calculate these padding bytes so we just copy + the whole thing for now. 
+ See https://github.com/rust-lang/unsafe-code-guidelines/issues/518 + And a proper implementation is here: + https://github.com/minirust/minirust/blob/master/tooling/minimize/src/chunks.rs *) + let+ blocks = get_all (Typed.cast layout.size, offset) in + Union blocks + | Primitive, TNever -> error `RefToUninhabited + | Primitive, TFnDef fnptr -> ok (ConstFn fnptr.binder_value) + | Primitive, _ -> query (ty, offset) + | Array _, (TRawPtr (pointee, _) | TRef (_, pointee, _)) -> ( + let+ vs = iter (iter_fields ~meta layout ty) offset in + let vs = as_tuple vs in + match (dst_kind pointee, vs) with + | LenKind, [ Ptr (base, Thin); Int len ] -> Ptr (base, Len len) + | VTableKind, [ Ptr (base, Thin); Ptr (vtable, Thin) ] -> + Ptr (base, VTable vtable) + | _ -> failwith "decode: invalid metadata for pointer type") + | Array _, _ -> iter (iter_fields ~meta layout ty) offset + | Arbitrary (variant, _), _ -> ( + let+ vs = iter (iter_fields ~meta layout ty) offset in + match ty with + | TAdt { id = TAdtId t_id; _ } when Crate.is_enum t_id -> + let variants = Crate.as_enum t_id in + let variant = Types.VariantId.nth variants variant in + let fields = as_tuple vs in + let discr = BV.of_literal variant.discriminant in + Enum (discr, fields) + | _ -> vs) + | Enum _, TAdt { id = TAdtId t_id; _ } -> + let variants = Crate.as_enum t_id in + let* variant = variant_of_enum ~offset ty in + let+ fields = iter (iter_fields ~variant ~meta layout ty) offset in + let fields = as_tuple fields in + let variant = Types.VariantId.nth variants variant in + let discr = BV.of_literal variant.discriminant in + Enum (discr, fields) + | Enum _, _ -> failwith "decode: expected enum type for enum layout" + and iter fields offset = + fold_iter fields ~init:[] ~f:(fun vs (ty, o) -> + let+ v = aux (offset +!!@ o) ty in + v :: vs) + |> (Fun.flip map) (fun vs -> Tuple (List.rev vs)) in aux offset + (** Ensures this value is valid for the given type. This includes checking + pointer metadata, e.g. 
slice lengths and vtables. The [fake_read] function + is used to simulate reading from memory to check the validity of a pointee + type. *) + let check_valid ~fake_read v ty st = + let open Rustsymex in + let open Syntax in + let open Result in + match (v, (ty : Types.ty)) with + | Ptr ((_, meta) as p), TRef (_, pointee, _) -> ( + let** () = + match meta with + | Thin -> ok () + | Len len -> + assert_or_error + (Usize.(0s) <=$@ len) + (`UBTransmute "Negative slice length") + | VTable _ -> + (* TODO: check the vtable pointer is of the right trait kind *) + ok () + in + let* opt_err, st = fake_read p pointee st in + match opt_err with Some err -> error err | None -> ok st) + | Ptr (p, _), TFnPtr _ -> + let++ () = + assert_or_error + (Typed.not (Sptr.sem_eq (Sptr.null_ptr ()) p)) + `UBDanglingPointer + in + st + | _ -> ok st + let with_constraints ~ty v = let constraints = Typed.conj @@ Layout.constraints ty v in let msg = Fmt.str "Constraints of %a unsatisfied" pp_literal_ty ty in @@ -574,7 +440,10 @@ module Make (Sptr : Sptr.S) = struct with_constraints ~ty v | (TRawPtr _ | TRef _ | TFnPtr _), Ptr (_, Thin) -> ok v | TRef _, Int _ -> error `UBDanglingPointer - | (TRawPtr _ | TFnPtr _), Int v -> ok (Ptr (Sptr.null_ptr_of v, Thin)) + | TFnPtr _, Int v -> + if%sat v ==@ Usize.(0s) then error `UBDanglingPointer + else ok (Ptr (Sptr.null_ptr_of v, Thin)) + | TRawPtr _, Int v -> ok (Ptr (Sptr.null_ptr_of v, Thin)) | _ -> Fmt.kstr not_impl "transmute_one: unsupported %a -> %a" pp_rust_val v pp_ty to_ty diff --git a/soteria-rust/lib/interp.ml b/soteria-rust/lib/interp.ml index dc624a856..bcf4bbfe9 100644 --- a/soteria-rust/lib/interp.ml +++ b/soteria-rust/lib/interp.ml @@ -273,7 +273,8 @@ module Make (State : State_intf.S) = struct Expressions.pp_field_proj_kind kind Types.pp_field_id field Sptr.pp ptr Sptr.pp ptr'); if not @@ Layout.is_inhabited ty then error `RefToUninhabited - else ok (ptr', meta) + else if Layout.is_dst ty then ok (ptr', meta) + else ok (ptr', Thin) | 
PlaceProjection (base, ProjIndex (idx, from_end)) -> let* ptr, meta = resolve_place base in let len = @@ -484,17 +485,24 @@ module Make (State : State_intf.S) = struct | _ -> not_impl "Invalid type for Neg") | Cast (CastRawPtr (from_ty, to_ty)) -> ( match (from_ty, to_ty) with - | (TRef _ | TRawPtr _), TLiteral _ -> + | (TRef _ | TRawPtr _), TLiteral to_ty -> (* expose provenance *) let v, _ = as_ptr v in - let+ v' = Sptr.expose v in - Int v' + let* v' = Sptr.expose v in + Encoder.cast_literal ~from_ty:(TUInt Usize) ~to_ty v' | TLiteral _, (TRef _ | TRawPtr _) -> (* with provenance *) let v = as_base_i Usize v in let+ ptr = State.with_exposed v in Ptr ptr - | _ -> ok v) + | _, (TRef (_, to_ty, _) | TRawPtr (to_ty, _)) -> ( + match (v, Layout.is_dst to_ty) with + | Ptr (ptr, _), false -> ok (Ptr (ptr, Thin)) + | Ptr (_, Thin), true -> + not_impl "Cannot cast to fat pointer without meta" + | Ptr _, true -> ok v + | _ -> not_impl "Invalid value for CastRawPtr") + | _ -> not_impl "Invalid types for CastRawPtr") | Cast (CastTransmute (from_ty, to_ty)) -> Core.transmute ~from_ty ~to_ty v | Cast (CastScalar (from_ty, to_ty)) -> @@ -605,14 +613,10 @@ module Make (State : State_intf.S) = struct in Core.eval_checked_lit_binop op ty v1 v2 | Cmp -> - let v1, v2, ty = Typed.cast_checked2 v1 v2 in - if Typed.equal_ty ty (Typed.t_ptr ()) then - error `UBPointerComparison - else - let ty = type_of_operand e1 in - let ty = TypesUtils.ty_as_literal ty in - let+ cmp = Core.cmp ~signed:(Layout.is_signed ty) v1 v2 in - Int cmp + let v1, v2, _ = Typed.cast_checked2 v1 v2 in + let ty = type_of_operand e1 in + let ty = TypesUtils.ty_as_literal ty in + ok (Core.cmp ~signed:(Layout.is_signed ty) v1 v2) | Offset -> (* non-zero offset on integer pointer is not permitted, as these are always dangling *) @@ -732,9 +736,8 @@ module Make (State : State_intf.S) = struct let field = Types.FieldId.to_int field in let* layout = Layout.layout_of (TAdt ty) in let offset = Layout.Fields_shape.offset_of 
field layout.fields in - let+ op_blocks = - Encoder.rust_to_cvals ~offset value (type_of_operand op) - in + let+ op_blocks = Encoder.encode ~offset value (type_of_operand op) in + let op_blocks = Iter.to_list op_blocks in Union op_blocks (* Tuple aggregate *) | Aggregate (AggregatedAdt ({ id = TTuple; _ }, None, None), operands) -> diff --git a/soteria-rust/lib/layout.ml b/soteria-rust/lib/layout.ml index 09d2f6ef8..0cb8d58c5 100644 --- a/soteria-rust/lib/layout.ml +++ b/soteria-rust/lib/layout.ml @@ -64,7 +64,7 @@ module Fields_shape = struct | Primitive -> failwith "This layout has no fields" | Enum _ -> failwith "Can't get fields of enum; use `shape_for_variant`" | Arbitrary (_, arr) -> arr.(f) - | Array stride -> BV.usizei f *!@ stride + | Array stride -> BV.usizei f *!!@ stride let shape_for_variant variant = function | Enum (_, shapes) -> shapes.(Types.VariantId.to_int variant) @@ -74,6 +74,8 @@ module Fields_shape = struct variant end +(* TODO: size should be an [option], for unsized types *) +(* TODO: add a uninhabited flag (concrete..?) 
*) type t = { size : T.sint Typed.t; align : T.nonzero Typed.t; @@ -166,7 +168,10 @@ let rec layout_of (ty : Types.ty) : (t, 'e, 'f) Rustsymex.Result.t = | TRawPtr (sub_ty, _) when is_dst sub_ty -> let ptr_size = Crate.pointer_size () in - ok (mk_concrete ~size:(ptr_size * 2) ~align:ptr_size ()) + ok + (mk_concrete ~size:(ptr_size * 2) ~align:ptr_size + ~fields:(Array (BV.usizei ptr_size)) + ()) (* Refs, pointers, boxes *) | TAdt { id = TBuiltin TBox; _ } | TRef (_, _, _) | TRawPtr (_, _) -> let ptr_size = Crate.pointer_size () in diff --git a/soteria-rust/lib/rtree_block.ml b/soteria-rust/lib/rtree_block.ml index aaa44b955..7d2e781dc 100644 --- a/soteria-rust/lib/rtree_block.ml +++ b/soteria-rust/lib/rtree_block.ml @@ -306,7 +306,7 @@ module Make (Sptr : Sptr.S) = struct let get_init_leaves (ofs : [< T.sint ] Typed.t) (size : [< T.nonzero ] Typed.t) (t : t option) : - (Encoder.cval_info list * t option, 'err, 'fix) Result.t = + ((rust_val * T.sint Typed.t) list * t option, 'err, 'fix) Result.t = let ((_, bound) as range) = Range.of_low_and_size ofs (Typed.cast size) in let@ t = with_bound_and_owned_check t bound in let replace_node node = Result.ok node in diff --git a/soteria-rust/lib/rust_state_m.ml b/soteria-rust/lib/rust_state_m.ml index 6b246b917..c2518caf7 100644 --- a/soteria-rust/lib/rust_state_m.ml +++ b/soteria-rust/lib/rust_state_m.ml @@ -145,11 +145,11 @@ module type S = sig module Encoder : sig include module type of Encoder.Make (RawState.Sptr) - val rust_to_cvals : - ?offset:Typed.T.sint Typed.t -> + val encode : + offset:Typed.T.sint Typed.t -> rust_val -> Types.ty -> - (cval_info list, 'env) monad + ((rust_val * Typed.T.sint Typed.t) Iter.t, 'env) monad val cast_literal : from_ty:Values.literal_type -> @@ -364,7 +364,7 @@ module Make (State : State_intf.S) : S with module RawState = State = struct let[@inline] fake_read ptr ty = fun env state -> - let* is_valid = fake_read ptr ty state in + let* is_valid, state = fake_read ptr ty state in match 
is_valid with | None -> ok () env state | Some err -> error err state @@ -395,8 +395,7 @@ module Make (State : State_intf.S) : S with module RawState = State = struct module Encoder = struct include Encoder.Make (RawState.Sptr) - let[@inline] rust_to_cvals ?offset v ty = - State.lift_err (rust_to_cvals ?offset v ty) + let[@inline] encode ~offset v ty = State.lift_err (encode ~offset v ty) let[@inline] cast_literal ~from_ty ~to_ty cval = State.with_decay_map_res (cast_literal ~from_ty ~to_ty cval) diff --git a/soteria-rust/lib/rust_val.ml b/soteria-rust/lib/rust_val.ml index df324aee0..2b9a4cf2b 100644 --- a/soteria-rust/lib/rust_val.ml +++ b/soteria-rust/lib/rust_val.ml @@ -88,6 +88,12 @@ let as_base ty = function let as_base_i ty = as_base (TUInt ty) +let as_tuple = function + | Tuple vals -> vals + | v -> + Fmt.failwith "Unexpected rust_val kind, expected a tuple, got: %a" + ppa_rust_val v + let size_of = function | Int v -> Typed.size_of_int v / 8 | Float f -> Svalue.FloatPrecision.size (Typed.Float.fp_of f) / 8 diff --git a/soteria-rust/lib/sptr.ml b/soteria-rust/lib/sptr.ml index f1377a4d0..08fff5793 100644 --- a/soteria-rust/lib/sptr.ml +++ b/soteria-rust/lib/sptr.ml @@ -276,6 +276,7 @@ module ArithPtr : S with type t = arithptr_t = struct let+ loc_int, decay_map = DecayMap.decay ~expose ~size ~align loc decay_map in + L.debug (fun fmt -> fmt "Decay %a -> %a" Typed.ppa loc Typed.ppa loc_int); (loc_int +!!@ ofs, decay_map) let decay p = _decay ~expose:false p diff --git a/soteria-rust/lib/state.ml b/soteria-rust/lib/state.ml index 4398c9e0d..c65ca8614 100644 --- a/soteria-rust/lib/state.ml +++ b/soteria-rust/lib/state.ml @@ -270,15 +270,14 @@ let rec check_ptr_align ((ptr, meta) : 'a full_ptr) (ty : Types.ty) st = and load ?ignore_borrow ?(ref_checks = true) ((ptr, meta) as fptr) ty st : (Sptr.t rust_val * t, Error.t, serialized) Result.t = let** (), st = check_ptr_align fptr ty st in - let is_valid_ptr = - if ref_checks then fun ptr ty -> fake_read ptr ty st 
- else fun _ _ -> return None - in - let parser ~offset = Encoder.rust_of_cvals ~meta ~offset ~is_valid_ptr ty in - let++ value, st = apply_parser ?ignore_borrow ptr parser st in + let parser ~offset = Encoder.decode ~meta ~offset ty in (* decoding no longer takes a validity callback; the validity check runs after the read, below *) + let** value, st = apply_parser ?ignore_borrow ptr parser st in L.debug (fun f -> f "Finished reading rust value %a" (Rust_val.pp Sptr.pp) value); - (value, st) + if ref_checks then (* the decoded value is validated only when [ref_checks] is set; [fake_read] calls [load] with [~ref_checks:false] *) + let++ st = Encoder.check_valid ~fake_read value ty st in + (value, st) + else Result.ok (value, st) and load_discriminant ((ptr, _) as fptr) ty st = let** (), st = check_ptr_align fptr ty st in @@ -294,17 +293,27 @@ and load_discriminant ((ptr, _) as fptr) ty st = This could be fixed by lifting all misses individually inside [handler] and [get_all] in [apply_parser], but that's kind of a mess to change and not really worth it I believe; I don't think these misses matter at all (TBD). *) -and fake_read ptr ty st = - (* FIXME: i am not certain how one checks for the validity of a DST *) - if Layout.is_dst ty || Option.is_some (Layout.as_zst ty) then return None +and fake_read ((_, meta) as ptr) ty st = (* Yields [(err, state)]: [None] when the check is skipped or the load succeeds, [Some e] when the load fails. *) + let can_check_dst = + match meta with + | Thin -> true + | Len l -> + (* TODO: we don't support symbolic slices *) + Option.is_some (Typed.BitVec.to_z l) + | VTable _ -> + (* FIXME: i am not certain how one checks for the validity of a &dyn *) + false + in + if (not can_check_dst) || Option.is_some (Layout.as_zst ty) then + return (None, st) else ( L.debug (fun m -> m "Checking validity of %a for %a" (pp_full_ptr Sptr.pp) ptr Charon_util.pp_ty ty); let+ res = load ~ignore_borrow:true ~ref_checks:false ptr ty st in match res with - | Ok _ -> None - | Error e -> Some e + | Ok (_, st) -> (None, st) (* success returns the state produced by [load]; the error case below returns the incoming [st] *) + | Error e -> (Some e, st) | Missing _ -> failwith "Miss in fake_read") let check_ptr_align ptr ty st = @@ -340,16 +349,16 @@ let tb_load ((ptr : Sptr.t), _) ty st = Tree_block.tb_access ofs size tag tb block) let store ((ptr, _) as fptr) ty sval st = - let** 
parts = lift_err st @@ Encoder.rust_to_cvals sval ty in - if List.is_empty parts then Result.ok ((), st) + let** parts = lift_err st @@ Encoder.encode ~offset:Usize.(0s) sval ty in (* [parts] is an iterator of (value, offset) pairs, consumed by [Result.fold_iter] further down *) + if Iter.is_empty parts then Result.ok ((), st) else let** (), st = check_ptr_align fptr ty st in let@ () = with_error_loc_as_call_trace st in let@ () = with_loc_err () in - L.debug (fun f -> + (* L.debug (fun f -> f "Parsed to parts [%a]" Fmt.(list ~sep:comma Encoder.pp_cval_info) - parts); + parts); *) log "store" ptr st; let** size = Layout.size_of ty in let@ ofs, block = with_ptr ptr st in @@ -359,8 +368,10 @@ let store ((ptr, _) as fptr) ty sval st = (* We uninitialise the whole range before writing, to ensure padding bytes are copied if there are any. *) let** (), block = Tree_block.uninit_range ofs size block in - Result.fold_list parts ~init:((), block) + Result.fold_iter parts ~init:((), block) ~f:(fun ((), block) (value, offset) -> + (* L.warn (fun f -> + f "Storing part %a: %a" Typed.ppa offset (pp_rust_val Sptr.pp) value); *) Tree_block.store (offset +!!@ ofs) value ptr.tag tb block) let copy_nonoverlapping ~dst:(dst, _) ~src:(src, _) ~size st = diff --git a/soteria-rust/lib/state_intf.ml b/soteria-rust/lib/state_intf.ml index a2c3d4e09..f34c0984c 100644 --- a/soteria-rust/lib/state_intf.ml +++ b/soteria-rust/lib/state_intf.ml @@ -51,7 +51,7 @@ module type S = sig full_ptr list ret val free : full_ptr -> t -> unit ret - val fake_read : full_ptr -> Types.ty -> t -> Error.t option Rustsymex.t + val fake_read : full_ptr -> Types.ty -> t -> (Error.t option * t) Rustsymex.t val check_ptr_align : full_ptr -> Types.ty -> t -> unit ret val copy_nonoverlapping : diff --git a/soteria-rust/scripts/cliopts.py b/soteria-rust/scripts/cliopts.py index eac29347b..ac92c6d3c 100644 --- a/soteria-rust/scripts/cliopts.py +++ b/soteria-rust/scripts/cliopts.py @@ -31,6 +31,7 @@ class CliOpts(TypedDict): cmd: Cmd tool: ToolName tool_cmd: list[str] + cli_extra_flags: list[str] filters: list[str] 
exclusions: list[str] tag: Optional[str] @@ -60,6 +61,7 @@ def parse_flags() -> CliOpts: "cmd": cast(Cmd, None), "tool": "Rusteria", "tool_cmd": [], + "cli_extra_flags": [], "filters": [], "exclusions": [], "tag": None, @@ -186,6 +188,7 @@ def pop(): opts = opts_for_rusteria(opts, force_obol=(not with_charon)) opts["tool_cmd"] += cmd_flags + opts["cli_extra_flags"] = cmd_flags return opts diff --git a/soteria-rust/scripts/test.py b/soteria-rust/scripts/test.py index 18daf6d64..369a53c95 100755 --- a/soteria-rust/scripts/test.py +++ b/soteria-rust/scripts/test.py @@ -356,7 +356,7 @@ def run_benchmark(opts: CliOpts): test_conf = callback(opts) if len(test_conf["tests"]) == 0: continue - cmd = opts["tool_cmd"] + test_conf["args"] + cmd = opts["tool_cmd"] + test_conf["args"] + opts["cli_extra_flags"] pprint( f"{CYAN}{BOLD}==>{RESET} Running benchmark {BOLD}{test_conf['name']}{RESET} with {BOLD}{opts['tool']}", ) diff --git a/soteria-rust/test/cram/kani.t/run.t b/soteria-rust/test/cram/kani.t/run.t index a6c44f293..4f0d60fb6 100644 --- a/soteria-rust/test/cram/kani.t/run.t +++ b/soteria-rust/test/cram/kani.t/run.t @@ -100,95 +100,65 @@ Test kani::vec::any_vec (0x0000000000000000 == V|1|) /\ (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) PC 2: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x000000000000000f) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x000000000000000f) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffc2) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffc2) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x000000000000000f) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffc2) PC 3: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x000000000000000e) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - 
(V|1| == 0x000000000000000e) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffc6) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffc6) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x000000000000000e) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffc6) PC 4: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x000000000000000d) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x000000000000000d) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffca) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffca) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x000000000000000d) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffca) PC 5: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x000000000000000c) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x000000000000000c) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffce) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffce) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x000000000000000c) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffce) PC 6: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x000000000000000b) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x000000000000000b) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffd2) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffd2) + (0b00 
== extract[0-1](V|19|)) /\ (V|1| == 0x000000000000000b) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffd2) PC 7: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x000000000000000a) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x000000000000000a) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffd6) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffd6) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x000000000000000a) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffd6) PC 8: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x0000000000000009) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x0000000000000009) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffda) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffda) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x0000000000000009) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffda) PC 9: (extract[0-1](V|18|) == 0b00) /\ (0x0000000000000008 == V|1|) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (0x0000000000000008 == V|1|) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffde) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffde) + (0b00 == extract[0-1](V|19|)) /\ (0x0000000000000008 == V|1|) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffde) PC 10: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 
0x0000000000000007) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x0000000000000007) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffe2) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffe2) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x0000000000000007) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffe2) PC 11: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x0000000000000006) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x0000000000000006) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffe6) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffe6) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x0000000000000006) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffe6) PC 12: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x0000000000000005) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x0000000000000005) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7fffffffffffffea) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffea) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x0000000000000005) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffea) PC 13: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x0000000000000004) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x0000000000000004) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 
0x7fffffffffffffee) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7fffffffffffffee) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x0000000000000004) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7fffffffffffffee) PC 14: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x0000000000000003) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x0000000000000003) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7ffffffffffffff2) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7ffffffffffffff2) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x0000000000000003) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7ffffffffffffff2) PC 15: (extract[0-1](V|18|) == 0b00) /\ (V|1| == 0x0000000000000002) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (V|1| == 0x0000000000000002) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7ffffffffffffff6) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7ffffffffffffff6) + (0b00 == extract[0-1](V|19|)) /\ (V|1| == 0x0000000000000002) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + (0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7ffffffffffffff6) PC 16: (extract[0-1](V|18|) == 0b00) /\ (0x0000000000000001 == V|1|) /\ - (0b00 == extract[0-1](V|19|)) /\ (0b00 == extract[0-1](V|20|)) /\ - (0x0000000000000001 == V|1|) /\ (0x0000000000000001 <=u V|18|) /\ - (V|18| <=u 0x7fffffffffffffbe) /\ (0x0000000000000001 <=u V|19|) /\ - (V|19| <=u 0x7ffffffffffffffa) /\ (0x0000000000000001 <=u V|20|) /\ - (V|20| <=u 0x7ffffffffffffffa) + (0b00 == extract[0-1](V|19|)) /\ (0x0000000000000001 == V|1|) /\ + (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) /\ + 
(0x0000000000000001 <=u V|19|) /\ (V|19| <=u 0x7ffffffffffffffa) PC 17: (extract[0-1](V|18|) == 0b00) /\ (0x0000000000000010 == V|1|) /\ (0x0000000000000010 == V|1|) /\ (0x0000000000000001 <=u V|18|) /\ (V|18| <=u 0x7fffffffffffffbe) diff --git a/soteria/lib/soteria_std/iter_soteria.ml b/soteria/lib/soteria_std/iter_soteria.ml new file mode 100644 index 000000000..8405bf845 --- /dev/null +++ b/soteria/lib/soteria_std/iter_soteria.ml @@ -0,0 +1,38 @@ +include Iter + +let rec of_list_combine l1 l2 k = (* Feeds [k] the pairs of same-position elements of [l1] and [l2], two pairs per recursive step; raises [Invalid_argument] on a length mismatch, possibly after some pairs were already yielded. *) + match (l1, l2) with + | [], [] -> () + | [ a1 ], [ a2 ] -> k (a1, a2) + | a1 :: b1 :: l1, a2 :: b2 :: l2 -> + k (a1, a2); + k (b1, b2); + of_list_combine l1 l2 k + | _, _ -> invalid_arg "Iter.of_list_combine" + +(** [combine_list i l] pairs each element of the iterator [i] with the + element of the list [l] at the same position. If [l] is shorter than [i], + [Invalid_argument] is raised once [i] yields an element past the end of + [l]; surplus elements of [l] are ignored. *) +let[@inline] combine_list (i : 'a t) (l : 'b list) k = + let l = ref l in + i (fun x -> + match !l with + | [] -> invalid_arg "Iter.combine_list" + | a :: b -> + l := b; + k (x, a)) + +let[@inline] repeati n x k = (* Calls [k x] exactly [n] times; never calls it when [n <= 0]. *) + let i = ref 0 in + while !i < n do + incr i; + k x + done + +let[@inline] repeatz n x k = (* Same as [repeati], with the count [n] given as a [Z.t]. *) + let i = ref Z.zero in + while Z.lt !i n do + i := Z.succ !i; + k x + done diff --git a/soteria/lib/soteria_std/soteria_std.ml b/soteria/lib/soteria_std/soteria_std.ml index 8170035da..adec2b7d3 100644 --- a/soteria/lib/soteria_std/soteria_std.ml +++ b/soteria/lib/soteria_std/soteria_std.ml @@ -9,6 +9,7 @@ module Graph = Graph module Hashset = Hashset module Hashtbl = Hashtbl module Int = Int +module Iter = Iter_soteria module List = List module Map = Map module Monad = Monad diff --git a/soteria/tests/soteria_std/iter_tests.ml b/soteria/tests/soteria_std/iter_tests.ml new file mode 100644 index 000000000..4ccd3670d --- /dev/null +++ b/soteria/tests/soteria_std/iter_tests.ml @@ -0,0 +1,87 @@ +open Test_register + +let register = 
register "Iter" + +open Iter + +let iter_testable : int Iter.t Alcotest.testable = (* Alcotest testable for [int Iter.t]; equality materialises both iterators into lists. *) + (module struct + type t = int Iter.t + + let pp = Fmt.(iter ~sep:comma Iter.iter int) + + let equal l r = + let l = to_list l in + let r = to_list r in + List.equal Int.equal l r + end) + +let check_iter = Alcotest.(check iter_testable) "iterators are equal" (* NOTE(review): [Alcotest.check] takes [expected] before [actual]; the callers below pass the result first, which should only affect failure messages -- confirm. *) + +let of_list_combine_test = (* four-element lists: covers the recursive two-pairs-at-a-time case *) + let@ () = register "of_list_combine" in + let l1 = [ 1; 2; 3; 4 ] in + let l2 = [ 10; 20; 30; 40 ] in + let expected = Iter.of_list [ 11; 22; 33; 44 ] in + let result = Iter.of_list_combine l1 l2 |> Iter.map (fun (a, b) -> a + b) in + check_iter result expected + +let of_list_combine_empty_test = (* both lists empty: nothing is yielded *) + let@ () = register "of_list_combine_empty" in + let l1 = [] in + let l2 = [] in + let expected = Iter.of_list [] in + let result = Iter.of_list_combine l1 l2 |> Iter.map (fun (a, b) -> a + b) in + check_iter result expected + +let of_list_combine_one_elem_test = (* singleton lists: covers the [[ a1 ], [ a2 ]] case *) + let@ () = register "of_list_combine_one_elem" in + let l1 = [ 7 ] in + let l2 = [ 3 ] in + let expected = Iter.of_list [ 10 ] in + let result = Iter.of_list_combine l1 l2 |> Iter.map (fun (a, b) -> a + b) in + check_iter result expected + +let of_list_combine_mismatched_length_test = (* the exception only surfaces once [Iter.to_list] consumes the iterator *) + let@ () = register "of_list_combine_mismatched_length" in + let l1 = [ 1; 2 ] in + let l2 = [ 10 ] in + Alcotest.check_raises "lists of mismatched length" + (Invalid_argument "Iter.of_list_combine") (fun () -> + ignore + (Iter.of_list_combine l1 l2 + |> Iter.map (fun (a, b) -> a + b) + |> Iter.to_list)) + +let join_list_n_test = (* the [join_list_*] tests exercise [Iter.combine_list], despite their names *) + let@ () = register "join_list_n" in + let i = Iter.of_list [ 1; 2; 3; 4 ] in + let l = [ 10; 20; 30; 40 ] in + let expected = Iter.of_list [ 11; 22; 33; 44 ] in + let result = Iter.combine_list i l |> Iter.map (fun (a, b) -> a + b) in + check_iter result expected + +let join_list_0_test = (* [combine_list] on an empty iterator and an empty list *) + let@ () = register "join_list_0" in + let i0 = Iter.of_list [] in + let l0 = [] in + let expected0 = Iter.of_list [] in + let result0 = 
Iter.combine_list i0 l0 |> Iter.map (fun (a, b) -> a + b) in + check_iter result0 expected0 + +let join_list_1_test = (* [combine_list] on a single-element iterator and list *) + let@ () = register "join_list_1" in + let i1 = Iter.of_list [ 5 ] in + let l1 = [ 100 ] in + let expected1 = Iter.of_list [ 105 ] in + let result1 = Iter.combine_list i1 l1 |> Iter.map (fun (a, b) -> a + b) in + check_iter result1 expected1 + +let join_list_shorter_list_test = (* list shorter than the iterator: [Invalid_argument] is expected once the iterator is consumed *) + let@ () = register "join_list_shorter_list_test" in + let i = Iter.of_list [ 1; 2; 3 ] in + let l = [ 10; 20 ] in + Alcotest.check_raises "list shorter than iterator" + (Invalid_argument "Iter.combine_list") (fun () -> + ignore + (Iter.combine_list i l |> Iter.map (fun (a, b) -> a + b) |> Iter.to_list)) diff --git a/soteria/tests/soteria_std/soteria_std_tests.ml b/soteria/tests/soteria_std/soteria_std_tests.ml index 7c3e0ea8b..d4e8cd878 100644 --- a/soteria/tests/soteria_std/soteria_std_tests.ml +++ b/soteria/tests/soteria_std/soteria_std_tests.ml @@ -1,5 +1,6 @@ (* Make sure modules are loaded and tests are registered *) module _ = List_tests module _ = Graph_tests +module _ = Iter_tests let () = Test_register.run_all ()