Skip to content

Commit

Permalink
sort unions and intersections lazily
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Jan 17, 2025
1 parent 1591afe commit 0034a03
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 30 deletions.
85 changes: 58 additions & 27 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use itertools::Itertools;
use ruff_db::diagnostic::Severity;
use ruff_db::files::File;
use ruff_python_ast as ast;
use type_ordering::order_union_elements;

pub(crate) use self::builder::{IntersectionBuilder, UnionBuilder};
pub(crate) use self::diagnostic::register_lints;
Expand Down Expand Up @@ -1084,9 +1085,15 @@ impl<'db> Type<'db> {

// TODO equivalent but not identical structural types

// For all other cases, types are equivalent iff they have the same internal
// representation.
self == other
match (self, other) {
(Type::Union(left), Type::Union(right)) => {
left.to_sorted_union(db) == right.to_sorted_union(db)
}
(Type::Intersection(left), Type::Intersection(right)) => {
left.to_sorted_intersection(db) == right.to_sorted_intersection(db)
}
_ => self == other,
}
}

/// Returns true if both `self` and `other` are the same gradual form
Expand Down Expand Up @@ -1128,16 +1135,6 @@ impl<'db> Type<'db> {

(Type::Dynamic(_), Type::Dynamic(_)) => true,

(Type::Instance(instance), Type::SubclassOf(subclass))
| (Type::SubclassOf(subclass), Type::Instance(instance)) => {
let Some(base_class) = subclass.subclass_of().into_class() else {
return false;
};

instance.class.is_known(db, KnownClass::Type)
&& base_class.is_known(db, KnownClass::Object)
}

(Type::SubclassOf(first), Type::SubclassOf(second)) => {
match (first.subclass_of(), second.subclass_of()) {
(first, second) if first == second => true,
Expand All @@ -1154,33 +1151,40 @@ impl<'db> Type<'db> {
&& iter::zip(first_elements, second_elements).all(equivalent)
}

// TODO: Handle equivalent unions with items in different order
(Type::Union(first), Type::Union(second)) => {
let first_elements = first.elements(db);
let second_elements = second.elements(db);
let first = first.to_sorted_union(db);
let second = second.to_sorted_union(db);

if first_elements.len() != second_elements.len() {
return false;
if first == second {
return true;
}

iter::zip(first_elements, second_elements).all(equivalent)
let first_elements = first.elements(db);
let second_elements = second.elements(db);

// TODO: Unknown ≡ Any, a union might contain both, etc.
first_elements.len() == second_elements.len()
&& iter::zip(first_elements, second_elements).all(equivalent)
}

// TODO: Handle equivalent intersections with items in different order
(Type::Intersection(first), Type::Intersection(second)) => {
let first = first.to_sorted_intersection(db);
let second = second.to_sorted_intersection(db);

if first == second {
return true;
}

let first_positive = first.positive(db);
let first_negative = first.negative(db);

let second_positive = second.positive(db);
let second_negative = second.negative(db);

if first_positive.len() != second_positive.len()
|| first_negative.len() != second_negative.len()
{
return false;
}

iter::zip(first_positive, second_positive).all(equivalent)
// TODO: Unknown ≡ Any, an intersection might contain both, etc.
first_positive.len() == second_positive.len()
&& first_negative.len() == second_negative.len()
&& iter::zip(first_positive, second_positive).all(equivalent)
&& iter::zip(first_negative, second_negative).all(equivalent)
}

Expand Down Expand Up @@ -4180,6 +4184,7 @@ pub struct UnionType<'db> {
elements_boxed: Box<[Type<'db>]>,
}

#[salsa::tracked]
impl<'db> UnionType<'db> {
fn elements(self, db: &'db dyn Db) -> &'db [Type<'db>] {
self.elements_boxed(db)
Expand Down Expand Up @@ -4208,6 +4213,18 @@ impl<'db> UnionType<'db> {
) -> Type<'db> {
Self::from_elements(db, self.elements(db).iter().map(transform_fn))
}

#[salsa::tracked]
fn to_sorted_union(self, db: &'db dyn Db) -> UnionType<'db> {
let mut elements = self.elements_boxed(db).to_vec();
for element in &mut elements {
if let Type::Intersection(intersection) = element {
*element = Type::Intersection(intersection.to_sorted_intersection(db));
}
}
elements.sort_unstable_by(|left, right| order_union_elements(db, left, right));
UnionType::new(db, elements.into_boxed_slice())
}
}

#[salsa::interned]
Expand All @@ -4225,6 +4242,20 @@ pub struct IntersectionType<'db> {
negative: FxOrderSet<Type<'db>>,
}

#[salsa::tracked]
impl<'db> IntersectionType<'db> {
#[salsa::tracked]
fn to_sorted_intersection(self, db: &'db dyn Db) -> IntersectionType<'db> {
let mut positive = self.positive(db).clone();
positive.sort_unstable_by(|left, right| order_union_elements(db, left, right));

let mut negative = self.negative(db).clone();
negative.sort_unstable_by(|left, right| order_union_elements(db, left, right));

IntersectionType::new(db, positive, negative)
}
}

#[salsa::interned]
pub struct StringLiteralType<'db> {
#[return_ref]
Expand Down
18 changes: 15 additions & 3 deletions crates/red_knot_python_semantic/src/types/type_ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,26 @@ pub(super) fn order_union_elements<'db>(
(_, Type::Dynamic(_)) => Ordering::Greater,

(Type::Union(left), Type::Union(right)) => {
order_sequences(db, left.elements(db), right.elements(db))
let left = left.to_sorted_union(db);
let right = right.to_sorted_union(db);
if left == right {
Ordering::Equal
} else {
order_sequences(db, left.elements(db), right.elements(db))
}
}
(Type::Union(_), _) => Ordering::Less,
(_, Type::Union(_)) => Ordering::Greater,

(Type::Intersection(left), Type::Intersection(right)) => {
order_sequences(db, left.positive(db), right.positive(db))
.then_with(|| order_sequences(db, left.negative(db), right.negative(db)))
let left = left.to_sorted_intersection(db);
let right = right.to_sorted_intersection(db);
if left == right {
Ordering::Equal
} else {
order_sequences(db, left.positive(db), right.positive(db))
.then_with(|| order_sequences(db, left.negative(db), right.negative(db)))
}
}
}
}
Expand Down

0 comments on commit 0034a03

Please sign in to comment.