Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
weiznich committed Aug 4, 2020
1 parent 577e224 commit c4d15b0
Show file tree
Hide file tree
Showing 13 changed files with 63 additions and 48 deletions.
2 changes: 1 addition & 1 deletion diesel/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ where
DB: Backend,
{
/// Construct an instance of `Self` from the database row
fn build(row: &impl NamedRow<DB>) -> Result<Self>;
fn build<'a>(row: &impl NamedRow<'a, DB>) -> Result<Self>;
}

#[doc(inline)]
Expand Down
2 changes: 1 addition & 1 deletion diesel/src/mysql/connection/stmt/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ pub struct MysqlField<'a> {
}

impl<'a> Field<'a, Mysql> for MysqlField<'a> {
fn field_name(&self) -> Option<&str> {
fn field_name(&self) -> Option<&'a str> {
self.metadata.field_name()
}

Expand Down
7 changes: 6 additions & 1 deletion diesel/src/mysql/connection/stmt/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ impl<'a> MysqlFieldMetadata<'a> {
if self.0.name.is_null() {
None
} else {
unsafe { CStr::from_ptr(self.0.name).to_str().ok() }
unsafe {
Some(CStr::from_ptr(self.0.name).to_str().expect(
"Expect mysql field names to be UTF-8, because we \
requested UTF-8 encoding on connection setup",
))
}
}
}

Expand Down
7 changes: 3 additions & 4 deletions diesel/src/pg/connection/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ impl<'a> Iterator for Cursor<'a> {
type Item = PgRow<'a>;

fn next(&mut self) -> Option<Self::Item> {
if self.current_row >= self.db_result.num_rows() {
None
} else {
if self.current_row < self.db_result.num_rows() {
let row = self.db_result.get_row(self.current_row);
self.current_row += 1;

Some(row)
} else {
None
}
}

Expand Down
16 changes: 9 additions & 7 deletions diesel/src/pg/connection/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,11 @@ impl PgResult {
}

pub fn column_type(&self, col_idx: usize) -> NonZeroU32 {
unsafe {
NonZeroU32::new_unchecked(PQftype(
self.internal_result.as_ptr(),
col_idx as libc::c_int,
))
}
let type_oid = unsafe { PQftype(self.internal_result.as_ptr(), col_idx as libc::c_int) };
NonZeroU32::new(type_oid).expect(
"Got a zero oid from postgres. If you see this error message \
please report it as issue on the diesel github bug tracker.",
)
}

pub fn column_name(&self, col_idx: usize) -> Option<&str> {
Expand All @@ -128,7 +127,10 @@ impl PgResult {
if ptr.is_null() {
None
} else {
Some(CStr::from_ptr(ptr).to_str().expect("Utf8"))
Some(CStr::from_ptr(ptr).to_str().expect(
"Expect postgres field names to be UTF-8, because we \
requested UTF-8 encoding on connection setup",
))
}
}
}
Expand Down
22 changes: 11 additions & 11 deletions diesel/src/pg/connection/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,11 @@ impl<'a> Row<'a, Pg> for PgRow<'a> {
Self: RowIndex<I>,
{
let idx = self.idx(idx)?;
if idx < self.field_count() {
Some(PgField {
db_result: self.db_result,
row_idx: self.row_idx,
col_idx: idx,
})
} else {
None
}
Some(PgField {
db_result: self.db_result,
row_idx: self.row_idx,
col_idx: idx,
})
}

fn partial_row(&self, range: std::ops::Range<usize>) -> PartialRow<Self::InnerPartialRow> {
Expand All @@ -45,7 +41,11 @@ impl<'a> Row<'a, Pg> for PgRow<'a> {

impl<'a> RowIndex<usize> for PgRow<'a> {
fn idx(&self, idx: usize) -> Option<usize> {
Some(idx)
if idx < self.field_count() {
Some(idx)
} else {
None
}
}
}

Expand All @@ -62,7 +62,7 @@ pub struct PgField<'a> {
}

impl<'a> Field<'a, Pg> for PgField<'a> {
fn field_name(&self) -> Option<&str> {
fn field_name(&self) -> Option<&'a str> {
self.db_result.column_name(self.col_idx)
}

Expand Down
31 changes: 18 additions & 13 deletions diesel/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,17 @@ pub struct PartialRow<'a, R> {

impl<'a, R> PartialRow<'a, R> {
#[doc(hidden)]
pub fn new(inner: &'a R, range: Range<usize>) -> Self {
Self { inner, range }
pub fn new<'b, DB>(inner: &'a R, range: Range<usize>) -> Self
where
R: Row<'b, DB>,
DB: Backend,
{
let range_lower = std::cmp::min(range.start, inner.field_count());
let range_upper = std::cmp::min(range.end, inner.field_count());
Self {
inner,
range: range_lower..range_upper,
}
}
}

Expand All @@ -110,24 +119,20 @@ where
type InnerPartialRow = R;

fn field_count(&self) -> usize {
let inner_length = self.inner.field_count();
if self.range.start < inner_length {
std::cmp::min(inner_length - self.range.start, self.range.len())
} else {
0
}
self.range.len()
}

fn get<I>(&self, idx: I) -> Option<Self::Field>
where
Self: RowIndex<I>,
{
let idx = self.idx(idx)?;
Some(self.inner.get(idx).unwrap())
self.inner.get(idx)
}

fn partial_row(&self, range: Range<usize>) -> PartialRow<R> {
let range = (self.range.start + range.start)..(self.range.start + range.end);
let range_upper_bound = std::cmp::min(self.range.end, self.range.start + range.end);
let range = (self.range.start + range.start)..range_upper_bound;
PartialRow {
inner: self.inner,
range,
Expand Down Expand Up @@ -168,7 +173,7 @@ where
///
/// This trait is used by implementations of
/// [`QueryableByName`](../deserialize/trait.QueryableByName.html)
pub trait NamedRow<DB: Backend> {
pub trait NamedRow<'a, DB: Backend>: Row<'a, DB> {
/// Retrieve and deserialize a single value from the query
///
/// Note that `ST` *must* be the exact type of the value with that name in
Expand All @@ -178,12 +183,12 @@ pub trait NamedRow<DB: Backend> {
///
/// If two or more fields in the query have the given name, the result of
/// this function is undefined.
fn get<'a, ST, T>(&self, column_name: &'a str) -> deserialize::Result<T>
fn get<'b, ST, T>(&self, column_name: &'b str) -> deserialize::Result<T>
where
T: FromSql<ST, DB>;
}

impl<'a, R, DB> NamedRow<DB> for R
impl<'a, R, DB> NamedRow<'a, DB> for R
where
R: Row<'a, DB>,
DB: Backend,
Expand Down
8 changes: 6 additions & 2 deletions diesel/src/sqlite/connection/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ impl<'a> Row<'a, Sqlite> for FunctionRow<'a> {

impl<'a> RowIndex<usize> for FunctionRow<'a> {
fn idx(&self, idx: usize) -> Option<usize> {
Some(idx)
if idx < self.args.len() {
Some(idx)
} else {
None
}
}
}

Expand All @@ -152,7 +156,7 @@ struct FunctionArgument<'a> {
}

impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> {
fn field_name(&self) -> Option<&str> {
fn field_name(&self) -> Option<&'a str> {
None
}

Expand Down
2 changes: 1 addition & 1 deletion diesel/src/sqlite/connection/sqlite_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub struct SqliteField<'a> {
}

impl<'a> Field<'a, Sqlite> for SqliteField<'a> {
fn field_name(&self) -> Option<&str> {
fn field_name(&self) -> Option<&'a str> {
column_name(self.stmt, self.col_idx)
}

Expand Down
2 changes: 1 addition & 1 deletion diesel/src/type_impls/option.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ where
DB: Backend,
T: QueryableByName<DB>,
{
fn build(row: &impl crate::row::NamedRow<DB>) -> deserialize::Result<Self> {
fn build<'a>(row: &impl crate::row::NamedRow<'a, DB>) -> deserialize::Result<Self> {
match T::build(row) {
Ok(v) => Ok(Some(v)),
Err(e) if e.is::<crate::result::UnexpectedNullError>() => Ok(None),
Expand Down
6 changes: 3 additions & 3 deletions diesel_derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,9 +630,9 @@ pub fn derive_queryable(input: TokenStream) -> TokenStream {
/// i32: FromSql<diesel::dsl::SqlTypeOf<users::id>, DB>,
/// String: FromSql<diesel::dsl::SqlTypeOf<users::name>, DB>,
/// {
/// fn build(row: &impl NamedRow<DB>) -> deserialize::Result<Self> {
/// let id = row.get::<diesel::dsl::SqlTypeOf<users::id>, _>("id")?;
/// let name = row.get::<diesel::dsl::SqlTypeOf<users::name>, _>("name")?;
/// fn build<'a>(row: &impl NamedRow<'a, DB>) -> deserialize::Result<Self> {
/// let id = NamedRow::get::<diesel::dsl::SqlTypeOf<users::id>, _>(row, "id")?;
/// let name = NamedRow::get::<diesel::dsl::SqlTypeOf<users::name>, _>(row, "name")?;
///
/// Ok(Self { id, name })
/// }
Expand Down
4 changes: 2 additions & 2 deletions diesel_derives/src/queryable_by_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno
let deserialize_ty = f.ty_for_deserialize()?;
Ok(quote!(
{
let field = row.get(stringify!(#name))?;
let field = diesel::row::NamedRow::get(row, stringify!(#name))?;
<#deserialize_ty as Into<#field_ty>>::into(field)
}
))
Expand Down Expand Up @@ -67,7 +67,7 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno
for #struct_name #ty_generics
#where_clause
{
fn build(row: &impl NamedRow<__DB>) -> deserialize::Result<Self>
fn build<'__a>(row: &impl NamedRow<'__a, __DB>) -> deserialize::Result<Self>
{


Expand Down
2 changes: 1 addition & 1 deletion diesel_tests/tests/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ fn boolean_from_sql() {
}

#[test]
fn boolean_treats_null_as_false_when_predicates_return_null() {
fn nullable_boolean_from_sql() {
let connection = connection();
let one = Some(1).into_sql::<diesel::sql_types::Nullable<Integer>>();
let query = select(one.eq(None::<i32>));
Expand Down

0 comments on commit c4d15b0

Please sign in to comment.