diff --git a/crates/mun_codegen/Cargo.toml b/crates/mun_codegen/Cargo.toml index 696d1b004..725dc65a3 100644 --- a/crates/mun_codegen/Cargo.toml +++ b/crates/mun_codegen/Cargo.toml @@ -14,6 +14,7 @@ categories = ["Game development", "Mun"] [dependencies] abi = { version = "=0.2.0", path = "../mun_abi", package = "mun_abi" } +bytemuck = "1.4.1" hir = { version = "=0.2.0", path = "../mun_hir", package = "mun_hir" } itertools = "0.9.0" mun_codegen_macros = { path = "../mun_codegen_macros", package = "mun_codegen_macros" } diff --git a/crates/mun_codegen/src/ir/types.rs b/crates/mun_codegen/src/ir/types.rs index 324ee3ef5..8caab3ce2 100644 --- a/crates/mun_codegen/src/ir/types.rs +++ b/crates/mun_codegen/src/ir/types.rs @@ -1,4 +1,4 @@ -use crate::value::{AsValue, IrValueContext, SizedValueType, TransparentValue, Value}; +use crate::value::{AsValue, BytesOrPtr, IrValueContext, SizedValueType, TransparentValue, Value}; use itertools::Itertools; use mun_codegen_macros::AsValue; @@ -8,6 +8,10 @@ impl<'ink> TransparentValue<'ink> for abi::Guid { fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target> { self.0.as_value(context) } + + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![BytesOrPtr::Bytes(self.0.to_vec())] + } } impl<'ink> TransparentValue<'ink> for abi::Privacy { @@ -16,6 +20,10 @@ impl<'ink> TransparentValue<'ink> for abi::Privacy { fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target> { (*self as u8).as_value(context) } + + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![BytesOrPtr::Bytes(vec![*self as u8])] + } } impl<'ink> TransparentValue<'ink> for abi::TypeGroup { @@ -24,6 +32,10 @@ impl<'ink> TransparentValue<'ink> for abi::TypeGroup { fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target> { (*self as u8).as_value(context) } + + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![BytesOrPtr::Bytes(vec![*self as u8])] + } } impl<'ink> TransparentValue<'ink> for abi::StructMemoryKind { @@ -32,6 +44,10 @@ impl<'ink> TransparentValue<'ink> for abi::StructMemoryKind { fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target> { (self.clone() as u8).as_value(context) } + + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![BytesOrPtr::Bytes(vec![self.clone() as u8])] + } } #[derive(AsValue)] diff --git a/crates/mun_codegen/src/value/array_value.rs b/crates/mun_codegen/src/value/array_value.rs index 3f4efca4d..756d0d66c 100644 --- a/crates/mun_codegen/src/value/array_value.rs +++ b/crates/mun_codegen/src/value/array_value.rs @@ -1,10 +1,12 @@ -use crate::value::{ - AddressableType, AsValue, ConcreteValueType, IrTypeContext, IrValueContext, PointerValueType, - SizedValueType, TypeValue, Value, ValueType, +use super::{ + AddressableType, AsValue, ConcreteValueType, HasConstValue, IrTypeContext, IrValueContext, + PointerValueType, SizedValueType, TypeValue, Value, ValueType, +}; +use inkwell::{ + types::{BasicType, PointerType}, + values::PointerValue, + AddressSpace, }; -use inkwell::types::{BasicType, PointerType}; -use inkwell::values::PointerValue; -use inkwell::AddressSpace; impl<'ink, T: ConcreteValueType<'ink>> ConcreteValueType<'ink> for [T] { type Value = inkwell::values::ArrayValue<'ink>; @@ -91,6 +93,12 @@ impl_array!( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 24, 32, 36, 0x40, 0x80, 0x100 ); +impl<'ink, T: ConcreteValueType<'ink> + HasConstValue> HasConstValue for &[T] { + fn has_const_value() -> bool { + T::has_const_value() + } +} + impl<'ink, E: SizedValueType<'ink>, T: AsValue<'ink, E>> AsValue<'ink, [E]> for &[T] where E::Value: ConstArrayValue<'ink>, diff --git a/crates/mun_codegen/src/value/float_value.rs b/crates/mun_codegen/src/value/float_value.rs index 48e2f8125..7a7875cbe 100644 --- a/crates/mun_codegen/src/value/float_value.rs +++ b/crates/mun_codegen/src/value/float_value.rs @@ -1,8 +1,7 @@ use super::{ - AsValue, ConcreteValueType, IrTypeContext, IrValueContext, PointerValueType, SizedValueType, - Value, + AddressableType, AsBytesAndPtrs, AsValue, BytesOrPtr, ConcreteValueType, HasConstValue, + IrTypeContext, IrValueContext, PointerValueType, SizedValueType, Value, }; -use crate::value::AddressableType; use inkwell::{types::PointerType, AddressSpace}; impl<'ink> ConcreteValueType<'ink> for f32 { @@ -44,6 +43,18 @@ impl<'ink> PointerValueType<'ink> for f64 { impl<'ink> AddressableType<'ink, f32> for f32 {} impl<'ink> AddressableType<'ink, f64> for f64 {} +impl HasConstValue for f32 { + fn has_const_value() -> bool { + true + } +} + +impl HasConstValue for f64 { + fn has_const_value() -> bool { + true + } +} + impl<'ink> AsValue<'ink, f32> for f32 { fn as_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, f32> { Value::from_raw( @@ -51,6 +62,7 @@ impl<'ink> AsValue<'ink, f32> for f32 { ) } } + impl<'ink> AsValue<'ink, f64> for f64 { fn as_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, f64> { Value::from_raw( @@ -58,3 +70,15 @@ impl<'ink> AsValue<'ink, f64> for f64 { ) } } + +impl<'ink> AsBytesAndPtrs<'ink> for f32 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for f64 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} diff --git a/crates/mun_codegen/src/value/int_value.rs b/crates/mun_codegen/src/value/int_value.rs index 9d93e9fee..d352c67e5 100644 --- a/crates/mun_codegen/src/value/int_value.rs +++ b/crates/mun_codegen/src/value/int_value.rs @@ -1,6 +1,6 @@ use super::{ - AddressableType, AsValue, ConcreteValueType, IrTypeContext, IrValueContext, PointerValueType, - SizedValueType, Value, + AddressableType, AsBytesAndPtrs, AsValue, BytesOrPtr, ConcreteValueType, HasConstValue, + IrTypeContext, IrValueContext, PointerValueType, SizedValueType, Value, }; use inkwell::AddressSpace; @@ -24,6 +24,12 @@ macro_rules! impl_as_int_ir_value { } impl<'ink> AddressableType<'ink, $ty> for $ty {} + + impl HasConstValue for $ty { + fn has_const_value() -> bool { + true + } + } )* } } @@ -110,3 +116,51 @@ impl<'ink> AsValue<'ink, i64> for i64 { ) } } + +impl<'ink> AsBytesAndPtrs<'ink> for u8 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for u16 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for u32 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for u64 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for i8 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for i16 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for i32 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} + +impl<'ink> AsBytesAndPtrs<'ink> for i64 { + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![bytemuck::cast_ref::(self).to_vec().into()] + } +} diff --git a/crates/mun_codegen/src/value/mod.rs b/crates/mun_codegen/src/value/mod.rs index 0ff0ca748..ed672b8a3 100644 --- a/crates/mun_codegen/src/value/mod.rs +++ b/crates/mun_codegen/src/value/mod.rs @@ -37,7 +37,7 @@ use std::hash::Hash; /// is returned. This allows transparent composition. e.g.: /// /// ```rust -/// # use mun_codegen::value::{AsValue, IrValueContext, TransparentValue, Value}; +/// # use mun_codegen::value::{AsValue, BytesOrPtr, IrValueContext, TransparentValue, Value}; /// struct Foo { /// value: u32, /// bar: f32, @@ -49,6 +49,13 @@ use std::hash::Hash; /// fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target> { /// (self.value, self.bar).as_value(context) /// } +/// +/// fn as_bytes_and_ptrs(&self, context: &IrValueContext<'ink, '_, '_>) -> Vec> { +/// vec![ +/// bytemuck::cast_ref::(&self.value).to_vec().into(), +/// bytemuck::cast_ref::(&self.bar).to_vec().into(), +/// ] +/// } /// } /// ``` /// @@ -87,6 +94,42 @@ pub trait TransparentValue<'ink> { /// Converts the instance to the target value fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target>; + + /// Converts the instance to bytes and pointers. + fn as_bytes_and_ptrs(&self, context: &IrValueContext<'ink, '_, '_>) -> Vec>; +} + +/// Contains either a value converted to bytes or a pointer to the value. +/// +/// This is used for generating constant enum types. +#[derive(Clone, Debug)] +pub enum BytesOrPtr<'ink> { + Bytes(Vec), + UntypedPtr(PointerValue<'ink>), +} + +impl<'ink> From> for BytesOrPtr<'ink> { + fn from(bytes: Vec) -> Self { + BytesOrPtr::Bytes(bytes) + } +} + +impl<'ink> From> for BytesOrPtr<'ink> { + fn from(ptr: PointerValue<'ink>) -> Self { + BytesOrPtr::UntypedPtr(ptr) + } +} + +/// Converts a value to its raw byte representation, while leaving pointers intact. +pub trait AsBytesAndPtrs<'ink> { + /// Converts the instance to bytes and pointers. + fn as_bytes_and_ptrs(&self, context: &IrValueContext<'ink, '_, '_>) -> Vec>; +} + +/// Signals whether the instance can construct aa matching LLVM constant IR value. +pub trait HasConstValue { + /// Returns whether the instance can be converted into an LLVM IR value. + fn has_const_value() -> bool; } /// The context in which an `IrType` operates. @@ -389,6 +432,15 @@ impl< } } +impl<'ink, T> HasConstValue for T +where + T: TransparentValue<'ink>, +{ + fn has_const_value() -> bool { + true + } +} + // Transparent values can also be represented as `Value`. impl<'ink, T> AsValue<'ink, T> for T where diff --git a/crates/mun_codegen/src/value/pointer_value.rs b/crates/mun_codegen/src/value/pointer_value.rs index c42d99a3e..5bcfcaffd 100644 --- a/crates/mun_codegen/src/value/pointer_value.rs +++ b/crates/mun_codegen/src/value/pointer_value.rs @@ -1,13 +1,13 @@ -use crate::value::{ - AddressableType, ConcreteValueType, IrTypeContext, IrValueContext, PointerValueType, - SizedValueType, Value, +use super::{ + AddressableType, AsBytesAndPtrs, BytesOrPtr, ConcreteValueType, HasConstValue, IrTypeContext, + IrValueContext, PointerValueType, SizedValueType, Value, }; -use inkwell::types::PointerType; -use inkwell::AddressSpace; +use inkwell::{types::PointerType, AddressSpace}; impl<'ink, T: PointerValueType<'ink>> ConcreteValueType<'ink> for *const T { type Value = inkwell::values::PointerValue<'ink>; } + impl<'ink, T: PointerValueType<'ink>> ConcreteValueType<'ink> for *mut T { type Value = inkwell::values::PointerValue<'ink>; } @@ -22,6 +22,7 @@ impl<'ink, T: PointerValueType<'ink>> SizedValueType<'ink> for *mut T { T::get_ptr_type(context, None) } } + impl<'ink, T: PointerValueType<'ink>> PointerValueType<'ink> for *mut T { fn get_ptr_type( context: &IrTypeContext<'ink, '_>, @@ -30,6 +31,7 @@ impl<'ink, T: PointerValueType<'ink>> PointerValueType<'ink> for *mut T { Self::get_ir_type(context).ptr_type(address_space.unwrap_or(AddressSpace::Generic)) } } + impl<'ink, T: PointerValueType<'ink>> PointerValueType<'ink> for *const T { fn get_ptr_type( context: &IrTypeContext<'ink, '_>, @@ -49,3 +51,21 @@ impl<'ink, T: SizedValueType<'ink, Value = inkwell::values::PointerValue<'ink>>> impl<'ink, T> AddressableType<'ink, *const T> for *const T where *const T: ConcreteValueType<'ink> {} impl<'ink, T> AddressableType<'ink, *mut T> for *mut T where *mut T: ConcreteValueType<'ink> {} + +impl<'ink, T> HasConstValue for Value<'ink, T> +where + T: SizedValueType<'ink, Value = inkwell::values::PointerValue<'ink>>, +{ + fn has_const_value() -> bool { + true + } +} + +impl<'ink, T> AsBytesAndPtrs<'ink> for Value<'ink, T> +where + T: SizedValueType<'ink, Value = inkwell::values::PointerValue<'ink>>, +{ + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + vec![BytesOrPtr::UntypedPtr(self.value)] + } +} diff --git a/crates/mun_codegen/src/value/string.rs b/crates/mun_codegen/src/value/string.rs index e53086c8d..616cc0ee1 100644 --- a/crates/mun_codegen/src/value/string.rs +++ b/crates/mun_codegen/src/value/string.rs @@ -1,4 +1,4 @@ -use super::{AsValue, Global, IrValueContext, TransparentValue, Value}; +use super::{AsValue, BytesOrPtr, Global, IrValueContext, TransparentValue, Value}; use std::ffi::{CStr, CString}; /// Enables internalizing certain data structures like strings. @@ -57,6 +57,12 @@ impl<'ink> TransparentValue<'ink> for CString { fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target> { self.as_bytes_with_nul().as_value(context) } + + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + unreachable!( + "`as_bytes_and_ptrs` should never be called on a `String`, as it cannot be a member of an enum." + ) + } } impl<'ink> TransparentValue<'ink> for String { @@ -65,4 +71,10 @@ impl<'ink> TransparentValue<'ink> for String { fn as_target_value(&self, context: &IrValueContext<'ink, '_, '_>) -> Value<'ink, Self::Target> { self.as_bytes().as_value(context) } + + fn as_bytes_and_ptrs(&self, _: &IrValueContext<'ink, '_, '_>) -> Vec> { + unreachable!( + "`as_bytes_and_ptrs` should never be called on a `String`, as it cannot be a member of an enum." + ); + } } diff --git a/crates/mun_codegen_macros/Cargo.toml b/crates/mun_codegen_macros/Cargo.toml index 97b96ac1e..090ce96be 100644 --- a/crates/mun_codegen_macros/Cargo.toml +++ b/crates/mun_codegen_macros/Cargo.toml @@ -8,5 +8,6 @@ edition = "2018" proc-macro = true [dependencies] -syn="1.0" -quote="1.0" +proc-macro2 = "1.0" +quote = "1.0" +syn = "1.0" diff --git a/crates/mun_codegen_macros/src/lib.rs b/crates/mun_codegen_macros/src/lib.rs index 7501d9331..8e80be42f 100644 --- a/crates/mun_codegen_macros/src/lib.rs +++ b/crates/mun_codegen_macros/src/lib.rs @@ -1,8 +1,9 @@ #![cfg(not(tarpaulin_include))] use proc_macro::TokenStream; +use proc_macro2::Span; use quote::quote; -use syn::{parse_macro_input, Data, DeriveInput}; +use syn::{parse_macro_input, Data, DeriveInput, Ident, Index}; /// This procedural macro implements the `AsValue` trait as well as several required other traits. /// All of these traits enable creating an `inkwell::values::StructValue` from a generic struct, as @@ -14,7 +15,7 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { // Get the typename of the struct we're working with let ident = { - let ident = derive_input.ident; + let ident = &derive_input.ident; let generics = derive_input.generics; quote! { #ident #generics @@ -23,6 +24,17 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { match derive_input.data { Data::Struct(struct_data) => { + // Generate a list of functions that return `false` if the struct field does not have an + // equivalent constant IR value. + let field_has_const_values = struct_data.fields.iter().map(|f| { + let ty = &f.ty; + quote! { + if !<#ty>::has_const_value() { + return false; + } + } + }); + // Generate a list of struct fields' paddings. // // Expects: @@ -72,6 +84,7 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { }; let field_padding_values = field_padding_types.clone(); + let field_padding_bytes = field_padding_types.clone(); // Generate a list of where clauses that ensure that we can cast each field to an // `inkwell::types::BasicTypeEnum` @@ -85,6 +98,7 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { // Generate a list of where clauses that ensure that we can cast each field to an // `inkwell::values::BasicTypeValue` let field_types_values = struct_data.fields.iter().enumerate().map(|(idx, f)| { + let idx = Index::from(idx); let name = f.ident.as_ref().map(|i| quote! { #i }).unwrap_or_else(|| quote! { #idx }); quote! { { @@ -95,6 +109,36 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { } }); + // Generate a list of bytes and `inkwell::values::PointerValue`s for each field. + // + // Expects: + // - type_context: &IrTypeContext + // - fn padded_size(align: usize, data_size: usize) -> usize + // - field_padding: Vec + let field_bytes_and_ptrs = { + let field_bytes_and_ptrs = struct_data.fields.iter().enumerate().map(|(idx, f)| { + let idx = Index::from(idx); + let name = f + .ident + .as_ref() + .map(|i| quote! { #i }) + .unwrap_or_else(|| quote! { #idx }); + quote! { + self. #name .as_bytes_and_ptrs(context) + } + }); + + quote! {{ + let field_bytes_and_ptrs = vec![ #(#field_bytes_and_ptrs),* ]; + field_padding + .into_iter() + .map(|p| vec![BytesOrPtr::Bytes(vec![0u8; p])]) + .interleave(field_bytes_and_ptrs.into_iter()) + .flatten() + .collect::>() + }} + }; + // Generate Phase (quote! { impl<'ink> crate::value::ConcreteValueType<'ink> for #ident { @@ -153,8 +197,38 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { } } + impl<'ink> crate::value::HasConstValue for #ident { + fn has_const_value() -> bool { + use crate::value::HasConstValue; + #(#field_has_const_values)* + true + } + } + + impl<'ink> crate::value::AsBytesAndPtrs<'ink> for #ident { + fn as_bytes_and_ptrs( + &self, + context: &IrValueContext<'ink, '_, '_> + ) -> Vec> { + use crate::value::AsBytesAndPtrs; + + fn padded_size(align: usize, data_size: usize) -> usize { + ((data_size + align - 1) / align) * align + } + + // Aliasing to make sure that all procedurally generated macros can use the + // same variable name. + let type_context = context.type_context; + let field_padding = #field_padding_bytes; + + #field_bytes_and_ptrs + } + } + impl<'ink> crate::value::AsValue<'ink, #ident> for #ident { fn as_value(&self, context: &crate::value::IrValueContext<'ink, '_, '_>) -> crate::value::Value<'ink, Self> { + use crate::value::HasConstValue; + fn padded_size(align: usize, data_size: usize) -> usize { ((data_size + align - 1) / align) * align } @@ -164,8 +238,9 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { let type_context = context.type_context; let field_padding = #field_padding_values; - let struct_type = Self::get_ir_type(context.type_context); - // eprintln!("Constructing: {:?}", struct_type.print_to_string().to_string()); + if <#ident>::has_const_value() { + let struct_type = Self::get_ir_type(context.type_context); + // eprintln!("Constructing: {:?}", struct_type.print_to_string().to_string()); let field_values = vec![ #(#field_types_values),* ]; let struct_fields: Vec<_> = field_padding @@ -181,18 +256,61 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { (context.context.i8_type(), p) }; - let chunks: Vec<_> = (0..num_chunks) - .map(|_| ty.const_int(0, false)) - .collect(); + let chunks: Vec<_> = (0..num_chunks) + .map(|_| ty.const_int(0, false)) + .collect(); - ty.const_array(&chunks).into() - }) - .interleave(field_values.into_iter()) - .collect(); + ty.const_array(&chunks).into() + }) + .interleave(field_values.into_iter()) + .collect(); + + let value = struct_type.const_named_struct(&struct_fields); + // eprintln!("Done"); + crate::value::Value::from_raw(value) + } else { + use crate::value::{AsBytesAndPtrs, BytesOrPtr}; + use inkwell::values::BasicValueEnum; + + let field_bytes_and_ptrs = self + .as_bytes_and_ptrs(context) + .into_iter() + .fold(Vec::new(), |mut v, rhs| { + match rhs { + BytesOrPtr::Bytes(mut rhs) => { + if let Some(BytesOrPtr::Bytes(lhs)) = v.last_mut() { + lhs.append(&mut rhs); + } else { + v.push(BytesOrPtr::Bytes(rhs)); + } + } + BytesOrPtr::UntypedPtr(p) => { + v.push(BytesOrPtr::UntypedPtr(p)); + } + } + v + }); - let value = struct_type.const_named_struct(&struct_fields); - // eprintln!("Done"); - crate::value::Value::from_raw(value) + let byte_ty = ::get_ir_type(context.type_context); + + let field_values: Vec = field_bytes_and_ptrs + .into_iter() + .map(|f| match f { + BytesOrPtr::Bytes(b) => { + let bytes: Vec<_> = b + .into_iter() + .map(|b| byte_ty.const_int(u64::from(b), false)) + .collect(); + + byte_ty.const_array(&bytes).into() + } + BytesOrPtr::UntypedPtr(ptr) => ptr.into(), + }) + .collect(); + + let value = context.context.const_struct(&field_values, true); + Value::from_raw(value) + } } } @@ -202,8 +320,370 @@ pub fn as_value_derive(input: TokenStream) -> TokenStream { Data::Union(_) => { unimplemented!("#[derive(AsValue)] is not defined for unions!"); } - Data::Enum(_) => { - unimplemented!("#[derive(AsValue)] is not defined for enums!"); + Data::Enum(enum_data) => { + const SUPPORTED_TAG_SIZES: &[&str] = + &["u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]; + + let repr_ty = derive_input.attrs.iter().find_map(|a| { + a.parse_meta().map_or(None, |m| { + if let syn::Meta::List(list) = m { + let op = list + .path + .segments + .iter() + .next() + .map(|s| s.ident.to_string()); + + if op == Some("repr".to_string()) { + return list.nested.iter().next().and_then(|n| { + if let syn::NestedMeta::Meta(m) = n { + m.path().segments.iter().next().map(|s| s.ident.clone()) + } else { + None + } + }); + } + } + + None + }) + }); + + let repr_ty = if let Some(ident) = repr_ty { + let repr_ty = ident.to_string(); + if !SUPPORTED_TAG_SIZES.contains(&repr_ty.as_str()) { + eprintln!( + "`repr({})` is not supported by the `AsValue` macro.", + repr_ty + ); + } + + quote! { + #ident + } + } else { + // Default to u32 + quote! { + u32 + } + }; + + if enum_data.variants.is_empty() { + eprintln!("Enums with no variants are not supported by the `AsValue` macro.") + } + + let enum_name = &derive_input.ident; + + // Returns a variant's fields' paddings and the variant's size. + // + // Expects: + // - chunk_size: usize + // - fn padded_size(align: usize, data_size: usize) -> usize + let variant_type_field_paddings_and_sizes = enum_data.variants.iter().map(|v| { + let field_sizes = v.fields.iter().map(|f| { + let ty = &f.ty; + quote! {{ + let ir_type = <#ty>::get_ir_type(type_context); + type_context.target_data.get_store_size(&ir_type) as usize + }} + }); + + let field_alignments = v.fields.iter().map(|f| { + let ty = &f.ty; + quote! {{ + let ir_type = <#ty>::get_ir_type(type_context); + type_context.target_data.get_preferred_alignment(&ir_type) as usize + }} + }); + + quote! {{ + // Start with the tag's size (same as chunk_size) + let mut total_size = chunk_size; + + let field_sizes = [ #(#field_sizes),* ]; + let field_alignments = [ #(#field_alignments),* ]; + + let field_paddings: Vec = field_sizes + .iter() + .zip(field_alignments.iter()) + .map(|(size, align)| { + let padded_size = padded_size(*align, total_size); + let padding = padded_size - total_size; + total_size = padded_size + size; + padding + }) + .collect(); + + ( + field_paddings, + total_size, + ) + }} + }); + + let variant_value_field_paddings_and_sizes = + variant_type_field_paddings_and_sizes.clone(); + + let variant_type_alignments = enum_data.variants.iter().map(|v| { + let field_alignments = v.fields.iter().map(|f| { + let ty = &f.ty; + quote! {{ + let ir_type = <#ty>::get_ir_type(type_context); + type_context.target_data.get_preferred_alignment(&ir_type) as usize + }} + }); + + let variant_align = quote! {{ + let field_alignments = [#(#field_alignments),*]; + field_alignments.iter().max().cloned().unwrap_or(1) + }}; + + variant_align + }); + + let variant_value_alignments = variant_type_alignments.clone(); + + // Generate a list of bytes and `inkwell::values::PointerValue`s for each field. + // + // Expects: + // - context: &IrValueContext + // - enum_size: usize + // - variant_sizes: Vec + let variant_bytes_and_ptrs = { + let variant_bytes_and_ptrs_mapping = enum_data + .variants + .iter() + .enumerate() + .map(|(tag, v)| { + let tag = Index::from(tag); + let field_mappings = v.fields.iter().enumerate().map(|(idx, f)| { + let name = f.ident.as_ref().map(|i| quote! { #i }).unwrap_or_else(|| { + let concatenated = format!("t{}", idx); + let local = Ident::new(&concatenated, Span::call_site()); + let idx = Index::from(idx); + quote! { #idx: #local } + }); + + name + }); + + let field_bytes_and_ptrs = v.fields.iter().enumerate().map(|(idx, f)| { + let name = f.ident.as_ref().map(|i| quote! { #i }).unwrap_or_else(|| { + let concatenated = format!("t{}", idx); + let local = Ident::new(&concatenated, Span::call_site()); + quote! { #local } + }); + + quote! { + #name .as_bytes_and_ptrs(context) + } + }); + + let ident = &v.ident; + quote! { + #enum_name :: #ident { #(#field_mappings),* } => { + let (variant_field_paddings, variant_size) = + variant_field_paddings_and_sizes.get(#tag).expect( + "Number of `variant_field_paddings_and_sizes` does not match the number of variants." + ); + + let variant_field_paddings = variant_field_paddings + .iter() + .map(|p| vec![0u8; *p].into()); + + let field_bytes_and_ptrs = vec![ + vec![BytesOrPtr::Bytes( + bytemuck::cast_ref::<#repr_ty, [u8; std::mem::size_of::<#repr_ty>()]>(&#tag) + .to_vec() + )], + #(#field_bytes_and_ptrs),* + ]; + let mut field_bytes_and_ptrs: Vec<_> = field_bytes_and_ptrs + .iter() + .flatten() + .cloned() + .interleave(variant_field_paddings) + .collect(); + + let rear_padding = enum_size - variant_size; + field_bytes_and_ptrs.push(vec![0u8; rear_padding].into()); + + field_bytes_and_ptrs + } + } + }); + + quote! { + match self { + #(#variant_bytes_and_ptrs_mapping)* + } + } + }; + + // Generate Phase + (quote! { + impl<'ink> crate::value::ConcreteValueType<'ink> for #ident { + type Value = inkwell::values::StructValue<'ink>; + } + + impl<'ink> crate::value::SizedValueType<'ink> for #ident { + fn get_ir_type( + context: &crate::value::IrTypeContext<'ink, '_> + ) -> inkwell::types::StructType<'ink> { + use std::convert::TryFrom; + + let key = std::any::type_name::<#ident>(); + if let Some(value) = context.struct_types.borrow().get(&key) { + return *value; + }; + + // Aliasing to make sure that all procedurally generated macros can use the + // same variable name. + let type_context = context; + + // The chunk size is the same as the tag's size + let chunk_ty = <#repr_ty>::get_ir_type(type_context); + let chunk_size = std::mem::size_of::<#repr_ty>(); + + let variant_alignments = [#(#variant_type_alignments),*]; + let max_align = core::cmp::max( + chunk_size, + variant_alignments.iter().max().cloned().unwrap_or(1), + ); + + fn padded_size(align: usize, data_size: usize) -> usize { + ((data_size + align - 1) / align) * align + } + + let variant_field_paddings_and_sizes = [ #(#variant_type_field_paddings_and_sizes),* ]; + let max_size = variant_field_paddings_and_sizes + .iter() + .map(|(_, s)| *s) + .max() + .unwrap_or(0); + + // Add padding for the end of the variant + let enum_size = padded_size(chunk_size, max_size); + + // The tag is excluded from the number of chunks + let num_chunks = enum_size / chunk_size - 1; + let num_chunks = u32::try_from(num_chunks).expect( + "Number of chunks is too large (max: `u32::max()`)" + ); + + let struct_ty = type_context.context.opaque_struct_type(&key); + type_context.struct_types.borrow_mut().insert(key, struct_ty); + + struct_ty.set_body(&[ + <[#repr_ty; 0]>::get_ir_type(type_context).into(), + chunk_ty.into(), + chunk_ty.array_type(num_chunks).into(), + ], true); + + struct_ty + } + } + + impl<'ink> crate::value::PointerValueType<'ink> for #ident { + fn get_ptr_type(context: &crate::value::IrTypeContext<'ink, '_>, address_space: Option) -> inkwell::types::PointerType<'ink> { + Self::get_ir_type(context).ptr_type(address_space.unwrap_or(inkwell::AddressSpace::Generic)) + } + } + + impl<'ink> crate::value::HasConstValue for #ident { + fn has_const_value() -> bool { + false + } + } + + impl<'ink> crate::value::AsBytesAndPtrs<'ink> for #ident { + fn as_bytes_and_ptrs( + &self, + context: &IrValueContext<'ink, '_, '_> + ) -> Vec> { + use crate::value::{AsBytesAndPtrs, BytesOrPtr}; + + // Aliasing to make sure that all procedurally generated macros can use the + // same variable name. + let type_context = context.type_context; + + // The chunk size is the same as the tag's size + let chunk_ty = <#repr_ty>::get_ir_type(type_context); + let chunk_size = std::mem::size_of::<#repr_ty>(); + + let variant_alignments = [#(#variant_value_alignments),*]; + let max_align = core::cmp::max( + chunk_size, + variant_alignments.iter().max().cloned().unwrap_or(1), + ); + + fn padded_size(align: usize, data_size: usize) -> usize { + ((data_size + align - 1) / align) * align + } + + let variant_field_paddings_and_sizes = [ #(#variant_value_field_paddings_and_sizes),* ]; + + let max_size = variant_field_paddings_and_sizes + .iter() + .map(|(_, s)| *s) + .max() + .unwrap_or(0); + + // Add padding for the end of the variant + let enum_size = padded_size(chunk_size, max_size); + + #variant_bytes_and_ptrs + } + } + + impl<'ink> crate::value::AsValue<'ink, #ident> for #ident { + fn as_value(&self, context: &crate::value::IrValueContext<'ink, '_, '_>) -> crate::value::Value<'ink, Self> { + use crate::value::{AsBytesAndPtrs, BytesOrPtr}; + use inkwell::values::BasicValueEnum; + + let field_bytes_and_ptrs = self + .as_bytes_and_ptrs(context) + .into_iter() + .fold(Vec::new(), |mut v, rhs| { + match rhs { + BytesOrPtr::Bytes(mut rhs) => { + if let Some(BytesOrPtr::Bytes(lhs)) = v.last_mut() { + lhs.append(&mut rhs); + } else { + v.push(BytesOrPtr::Bytes(rhs)); + } + } + BytesOrPtr::UntypedPtr(p) => { + v.push(BytesOrPtr::UntypedPtr(p)); + } + } + v + }); + + let byte_ty = ::get_ir_type(context.type_context); + + let field_values: Vec = field_bytes_and_ptrs + .into_iter() + .map(|f| match f { + BytesOrPtr::Bytes(b) => { + let bytes: Vec<_> = b + .into_iter() + .map(|b| byte_ty.const_int(u64::from(b), false)) + .collect(); + + byte_ty.const_array(&bytes).into() + } + BytesOrPtr::UntypedPtr(ptr) => ptr.into(), + }) + .collect(); + + let value = context.context.const_struct(&field_values, true); + Value::from_raw(value) + } + } + + impl<'ink> crate::value::AddressableType<'ink, #ident> for #ident {} + }).into() } } }