gccrs: TyTy: use new subclass API

gcc/rust/ChangeLog:

	* typecheck/rust-tyty.cc (BaseType::is_unit): Refactor.
	(BaseType::satisfies_bound): Refactor.
	(BaseType::get_root): Refactor.
	(BaseType::destructure): Refactor.
	(BaseType::monomorphized_clone): Refactor.
	(BaseType::is_concrete): Refactor.
	(InferType::InferType): Refactor.
	(InferType::clone): Refactor.
	(InferType::apply_primitive_type_hint): Refactor.
	(StructFieldType::is_equal): Refactor.
	(ADTType::is_equal): Refactor.
	(handle_substitions): Refactor.
	(ADTType::handle_substitions): Refactor.
	(TupleType::TupleType): Refactor.
	(TupleType::is_equal): Refactor.
	(TupleType::handle_substitions): Refactor.

Signed-off-by: Jakub Dupak <dev@jakubdupak.com>
This commit is contained in:
Jakub Dupak 2023-10-05 12:36:08 +02:00 committed by Arthur Cohen
parent 47b88c0338
commit f97a9841dc

View file

@ -220,16 +220,15 @@ BaseType::is_unit () const
return true;
case TUPLE: {
const TupleType &tuple = *static_cast<const TupleType *> (x);
return tuple.num_fields () == 0;
return x->as<const TupleType> ()->num_fields () == 0;
}
case ADT: {
const ADTType &adt = *static_cast<const ADTType *> (x);
if (adt.is_enum ())
auto adt = x->as<const ADTType> ();
if (adt->is_enum ())
return false;
for (const auto &variant : adt.get_variants ())
for (const auto &variant : adt->get_variants ())
{
if (variant->num_fields () > 0)
return false;
@ -276,8 +275,6 @@ bool
BaseType::satisfies_bound (const TypeBoundPredicate &predicate,
bool emit_error) const
{
bool is_infer_var = destructure ()->get_kind () == TyTy::TypeKind::INFER;
const Resolver::TraitReference *query = predicate.get ();
for (const auto &bound : specified_bounds)
{
@ -286,7 +283,7 @@ BaseType::satisfies_bound (const TypeBoundPredicate &predicate,
return true;
}
if (is_infer_var)
if (destructure ()->is<InferType> ())
return true;
bool satisfied = false;
@ -435,28 +432,24 @@ BaseType::get_root () const
{
// FIXME this needs to be it its own visitor class with a vector adjustments
const TyTy::BaseType *root = this;
if (get_kind () == TyTy::REF)
{
const ReferenceType *r = static_cast<const ReferenceType *> (root);
root = r->get_base ()->get_root ();
}
else if (get_kind () == TyTy::POINTER)
{
const PointerType *r = static_cast<const PointerType *> (root);
root = r->get_base ()->get_root ();
}
// these are an unsize
else if (get_kind () == TyTy::SLICE)
if (const auto r = root->try_as<const ReferenceType> ())
{
root = r->get_base ()->get_root ();
}
else if (const auto r = root->try_as<const PointerType> ())
{
root = r->get_base ()->get_root ();
}
// these are an unsize
else if (const auto r = root->try_as<const SliceType> ())
{
const SliceType *r = static_cast<const SliceType *> (root);
root = r->get_element_type ()->get_root ();
}
// else if (get_kind () == TyTy::ARRAY)
// {
// const ArrayType *r = static_cast<const ArrayType *> (root);
// root = r->get_element_type ()->get_root ();
// }
// else if (const auto r = root->try_as<const ArrayType> ())
// {
// root = r->get_element_type ()->get_root ();
// }
return root;
}
@ -478,34 +471,27 @@ BaseType::destructure ()
return new ErrorType (get_ref ());
}
switch (x->get_kind ())
if (auto p = x->try_as<ParamType> ())
{
case TyTy::TypeKind::PARAM: {
TyTy::ParamType *p = static_cast<TyTy::ParamType *> (x);
TyTy::BaseType *pr = p->resolve ();
if (pr == x)
return pr;
auto pr = p->resolve ();
if (pr == x)
return pr;
x = pr;
}
break;
x = pr;
}
else if (auto p = x->try_as<PlaceholderType> ())
{
if (!p->can_resolve ())
return p;
case TyTy::TypeKind::PLACEHOLDER: {
TyTy::PlaceholderType *p = static_cast<TyTy::PlaceholderType *> (x);
if (!p->can_resolve ())
return p;
x = p->resolve ();
}
break;
case TyTy::TypeKind::PROJECTION: {
TyTy::ProjectionType *p = static_cast<TyTy::ProjectionType *> (x);
x = p->get ();
}
break;
default:
x = p->resolve ();
}
else if (auto p = x->try_as<ProjectionType> ())
{
x = p->get ();
}
else
{
return x;
}
}
@ -530,36 +516,27 @@ BaseType::destructure () const
return new ErrorType (get_ref ());
}
switch (x->get_kind ())
if (auto p = x->try_as<const ParamType> ())
{
case TyTy::TypeKind::PARAM: {
const TyTy::ParamType *p = static_cast<const TyTy::ParamType *> (x);
const TyTy::BaseType *pr = p->resolve ();
if (pr == x)
return pr;
auto pr = p->resolve ();
if (pr == x)
return pr;
x = pr;
}
break;
x = pr;
}
else if (auto p = x->try_as<const PlaceholderType> ())
{
if (!p->can_resolve ())
return p;
case TyTy::TypeKind::PLACEHOLDER: {
const TyTy::PlaceholderType *p
= static_cast<const TyTy::PlaceholderType *> (x);
if (!p->can_resolve ())
return p;
x = p->resolve ();
}
break;
case TyTy::TypeKind::PROJECTION: {
const TyTy::ProjectionType *p
= static_cast<const TyTy::ProjectionType *> (x);
x = p->get ();
}
break;
default:
x = p->resolve ();
}
else if (auto p = x->try_as<const ProjectionType> ())
{
x = p->get ();
}
else
{
return x;
}
}
@ -571,112 +548,81 @@ BaseType *
BaseType::monomorphized_clone () const
{
const TyTy::BaseType *x = destructure ();
switch (x->get_kind ())
if (auto arr = x->try_as<const ArrayType> ())
{
TyVar elm = arr->get_var_element_type ().monomorphized_clone ();
return new ArrayType (arr->get_ref (), arr->get_ty_ref (), ident.locus,
arr->get_capacity_expr (), elm,
arr->get_combined_refs ());
}
else if (auto slice = x->try_as<const SliceType> ())
{
TyVar elm = slice->get_var_element_type ().monomorphized_clone ();
return new SliceType (slice->get_ref (), slice->get_ty_ref (),
ident.locus, elm, slice->get_combined_refs ());
}
else if (auto ptr = x->try_as<const PointerType> ())
{
TyVar elm = ptr->get_var_element_type ().monomorphized_clone ();
return new PointerType (ptr->get_ref (), ptr->get_ty_ref (), elm,
ptr->mutability (), ptr->get_combined_refs ());
}
else if (auto ref = x->try_as<const ReferenceType> ())
{
TyVar elm = ref->get_var_element_type ().monomorphized_clone ();
return new ReferenceType (ref->get_ref (), ref->get_ty_ref (), elm,
ref->mutability (), ref->get_combined_refs ());
}
else if (auto tuple = x->try_as<const TupleType> ())
{
std::vector<TyVar> cloned_fields;
for (const auto &f : tuple->get_fields ())
cloned_fields.push_back (f.monomorphized_clone ());
return new TupleType (tuple->get_ref (), tuple->get_ty_ref (),
ident.locus, cloned_fields,
tuple->get_combined_refs ());
}
else if (auto fn = x->try_as<const FnType> ())
{
std::vector<std::pair<HIR::Pattern *, BaseType *>> cloned_params;
for (auto &p : fn->get_params ())
cloned_params.push_back ({p.first, p.second->monomorphized_clone ()});
BaseType *retty = fn->get_return_type ()->monomorphized_clone ();
return new FnType (fn->get_ref (), fn->get_ty_ref (), fn->get_id (),
fn->get_identifier (), fn->ident, fn->get_flags (),
fn->get_abi (), std::move (cloned_params), retty,
fn->clone_substs (), fn->get_combined_refs ());
}
else if (auto fn = x->try_as<const FnPtr> ())
{
std::vector<TyVar> cloned_params;
for (auto &p : fn->get_params ())
cloned_params.push_back (p.monomorphized_clone ());
TyVar retty = fn->get_var_return_type ().monomorphized_clone ();
return new FnPtr (fn->get_ref (), fn->get_ty_ref (), ident.locus,
std::move (cloned_params), retty,
fn->get_combined_refs ());
}
else if (auto adt = x->try_as<const ADTType> ())
{
std::vector<VariantDef *> cloned_variants;
for (auto &variant : adt->get_variants ())
cloned_variants.push_back (variant->monomorphized_clone ());
return new ADTType (adt->get_ref (), adt->get_ty_ref (),
adt->get_identifier (), adt->ident,
adt->get_adt_kind (), cloned_variants,
adt->clone_substs (), adt->get_repr_options (),
adt->get_used_arguments (),
adt->get_combined_refs ());
}
else
{
case PARAM:
case PROJECTION:
case PLACEHOLDER:
case INFER:
case BOOL:
case CHAR:
case INT:
case UINT:
case FLOAT:
case USIZE:
case ISIZE:
case NEVER:
case STR:
case DYNAMIC:
case CLOSURE:
case ERROR:
return x->clone ();
case ARRAY: {
const ArrayType &arr = *static_cast<const ArrayType *> (x);
TyVar elm = arr.get_var_element_type ().monomorphized_clone ();
return new ArrayType (arr.get_ref (), arr.get_ty_ref (), ident.locus,
arr.get_capacity_expr (), elm,
arr.get_combined_refs ());
}
break;
case SLICE: {
const SliceType &slice = *static_cast<const SliceType *> (x);
TyVar elm = slice.get_var_element_type ().monomorphized_clone ();
return new SliceType (slice.get_ref (), slice.get_ty_ref (),
ident.locus, elm, slice.get_combined_refs ());
}
break;
case POINTER: {
const PointerType &ptr = *static_cast<const PointerType *> (x);
TyVar elm = ptr.get_var_element_type ().monomorphized_clone ();
return new PointerType (ptr.get_ref (), ptr.get_ty_ref (), elm,
ptr.mutability (), ptr.get_combined_refs ());
}
break;
case REF: {
const ReferenceType &ref = *static_cast<const ReferenceType *> (x);
TyVar elm = ref.get_var_element_type ().monomorphized_clone ();
return new ReferenceType (ref.get_ref (), ref.get_ty_ref (), elm,
ref.mutability (), ref.get_combined_refs ());
}
break;
case TUPLE: {
const TupleType &tuple = *static_cast<const TupleType *> (x);
std::vector<TyVar> cloned_fields;
for (const auto &f : tuple.get_fields ())
cloned_fields.push_back (f.monomorphized_clone ());
return new TupleType (tuple.get_ref (), tuple.get_ty_ref (),
tuple.get_ident ().locus, cloned_fields,
tuple.get_combined_refs ());
}
break;
case FNDEF: {
const FnType &fn = *static_cast<const FnType *> (x);
std::vector<std::pair<HIR::Pattern *, BaseType *>> cloned_params;
for (auto &p : fn.get_params ())
cloned_params.push_back ({p.first, p.second->monomorphized_clone ()});
BaseType *retty = fn.get_return_type ()->monomorphized_clone ();
return new FnType (fn.get_ref (), fn.get_ty_ref (), fn.get_id (),
fn.get_identifier (), fn.ident, fn.get_flags (),
fn.get_abi (), std::move (cloned_params), retty,
fn.clone_substs (), fn.get_combined_refs ());
}
break;
case FNPTR: {
const FnPtr &fn = *static_cast<const FnPtr *> (x);
std::vector<TyVar> cloned_params;
for (auto &p : fn.get_params ())
cloned_params.push_back (p.monomorphized_clone ());
TyVar retty = fn.get_var_return_type ().monomorphized_clone ();
return new FnPtr (fn.get_ref (), fn.get_ty_ref (), fn.ident.locus,
std::move (cloned_params), retty,
fn.get_combined_refs ());
}
break;
case ADT: {
const ADTType &adt = *static_cast<const ADTType *> (x);
std::vector<VariantDef *> cloned_variants;
for (auto &variant : adt.get_variants ())
cloned_variants.push_back (variant->monomorphized_clone ());
return new ADTType (adt.get_ref (), adt.get_ty_ref (),
adt.get_identifier (), adt.ident,
adt.get_adt_kind (), cloned_variants,
adt.clone_substs (), adt.get_repr_options (),
adt.get_used_arguments (),
adt.get_combined_refs ());
}
break;
}
rust_unreachable ();
@ -714,122 +660,94 @@ bool
BaseType::is_concrete () const
{
const TyTy::BaseType *x = destructure ();
switch (x->get_kind ())
if (x->is<ParamType> () || x->is<ProjectionType> ())
{
case PARAM:
case PROJECTION:
return false;
// placeholder is a special case for this case when it is not resolvable
// it means we its just an empty placeholder associated type which is
// concrete
case PLACEHOLDER:
}
// placeholder is a special case for this case when it is not resolvable
// it means we its just an empty placeholder associated type which is
// concrete
else if (x->is<PlaceholderType> ())
{
return true;
}
else if (auto fn = x->try_as<const FnType> ())
{
for (const auto &param : fn->get_params ())
{
if (!param.second->is_concrete ())
return false;
}
return fn->get_return_type ()->is_concrete ();
}
else if (auto fn = x->try_as<const FnPtr> ())
{
for (const auto &param : fn->get_params ())
{
if (!param.get_tyty ()->is_concrete ())
return false;
}
return fn->get_return_type ()->is_concrete ();
}
else if (auto adt = x->try_as<const ADTType> ())
{
if (adt->is_unit ())
return !adt->needs_substitution ();
case FNDEF: {
const FnType &fn = *static_cast<const FnType *> (x);
for (const auto &param : fn.get_params ())
{
const BaseType *p = param.second;
if (!p->is_concrete ())
return false;
}
return fn.get_return_type ()->is_concrete ();
}
break;
for (auto &variant : adt->get_variants ())
{
bool is_num_variant
= variant->get_variant_type () == VariantDef::VariantType::NUM;
if (is_num_variant)
continue;
case FNPTR: {
const FnPtr &fn = *static_cast<const FnPtr *> (x);
for (const auto &param : fn.get_params ())
{
const BaseType *p = param.get_tyty ();
if (!p->is_concrete ())
return false;
}
return fn.get_return_type ()->is_concrete ();
}
break;
case ADT: {
const ADTType &adt = *static_cast<const ADTType *> (x);
if (adt.is_unit ())
{
return !adt.needs_substitution ();
}
for (auto &variant : adt.get_variants ())
{
bool is_num_variant
= variant->get_variant_type () == VariantDef::VariantType::NUM;
if (is_num_variant)
continue;
for (auto &field : variant->get_fields ())
{
const BaseType *field_type = field->get_field_type ();
if (!field_type->is_concrete ())
return false;
}
}
return true;
}
break;
case ARRAY: {
const ArrayType &arr = *static_cast<const ArrayType *> (x);
return arr.get_element_type ()->is_concrete ();
}
break;
case SLICE: {
const SliceType &slice = *static_cast<const SliceType *> (x);
return slice.get_element_type ()->is_concrete ();
}
break;
case POINTER: {
const PointerType &ptr = *static_cast<const PointerType *> (x);
return ptr.get_base ()->is_concrete ();
}
break;
case REF: {
const ReferenceType &ref = *static_cast<const ReferenceType *> (x);
return ref.get_base ()->is_concrete ();
}
break;
case TUPLE: {
const TupleType &tuple = *static_cast<const TupleType *> (x);
for (size_t i = 0; i < tuple.num_fields (); i++)
{
if (!tuple.get_field (i)->is_concrete ())
return false;
}
return true;
}
break;
case CLOSURE: {
const ClosureType &closure = *static_cast<const ClosureType *> (x);
if (closure.get_parameters ().is_concrete ())
return false;
return closure.get_result_type ().is_concrete ();
}
break;
case INFER:
case BOOL:
case CHAR:
case INT:
case UINT:
case FLOAT:
case USIZE:
case ISIZE:
case NEVER:
case STR:
case DYNAMIC:
case ERROR:
for (auto &field : variant->get_fields ())
{
const BaseType *field_type = field->get_field_type ();
if (!field_type->is_concrete ())
return false;
}
}
return true;
}
else if (auto arr = x->try_as<const ArrayType> ())
{
return arr->get_element_type ()->is_concrete ();
}
else if (auto slice = x->try_as<const SliceType> ())
{
return slice->get_element_type ()->is_concrete ();
}
else if (auto ptr = x->try_as<const PointerType> ())
{
return ptr->get_base ()->is_concrete ();
}
else if (auto ref = x->try_as<const ReferenceType> ())
{
return ref->get_base ()->is_concrete ();
}
else if (auto tuple = x->try_as<const TupleType> ())
{
for (size_t i = 0; i < tuple->num_fields (); i++)
{
if (!tuple->get_field (i)->is_concrete ())
return false;
}
return true;
}
else if (auto closure = x->try_as<const ClosureType> ())
{
if (closure->get_parameters ().is_concrete ())
return false;
return closure->get_result_type ().is_concrete ();
}
else if (x->is<InferType> () || x->is<BoolType> () || x->is<CharType> ()
|| x->is<IntType> () || x->is<UintType> () || x->is<FloatType> ()
|| x->is<USizeType> () || x->is<ISizeType> () || x->is<NeverType> ()
|| x->is<StrType> () || x->is<DynamicObjectType> ()
|| x->is<ErrorType> ())
{
return true;
}
@ -1197,10 +1115,9 @@ InferType::apply_primitive_type_hint (const BaseType &hint)
case INT: {
infer_kind = INTEGRAL;
const IntType &i = static_cast<const IntType &> (hint);
default_hint.kind = hint.get_kind ();
default_hint.shint = TypeHint::SignedHint::SIGNED;
switch (i.get_int_kind ())
switch (hint.as<const IntType> ()->get_int_kind ())
{
case IntType::I8:
default_hint.szhint = TypeHint::SizeHint::S8;
@ -1223,10 +1140,9 @@ InferType::apply_primitive_type_hint (const BaseType &hint)
case UINT: {
infer_kind = INTEGRAL;
const UintType &i = static_cast<const UintType &> (hint);
default_hint.kind = hint.get_kind ();
default_hint.shint = TypeHint::SignedHint::UNSIGNED;
switch (i.get_uint_kind ())
switch (hint.as<const UintType> ()->get_uint_kind ())
{
case UintType::U8:
default_hint.szhint = TypeHint::SizeHint::S8;
@ -1251,8 +1167,7 @@ InferType::apply_primitive_type_hint (const BaseType &hint)
infer_kind = FLOAT;
default_hint.shint = TypeHint::SignedHint::SIGNED;
default_hint.kind = hint.get_kind ();
const FloatType &i = static_cast<const FloatType &> (hint);
switch (i.get_float_kind ())
switch (hint.as<const FloatType> ()->get_float_kind ())
{
case FloatType::F32:
default_hint.szhint = TypeHint::SizeHint::S32;
@ -1371,14 +1286,11 @@ StructFieldType::as_string () const
bool
StructFieldType::is_equal (const StructFieldType &other) const
{
bool names_eq = get_name ().compare (other.get_name ()) == 0;
bool names_eq = get_name () == other.get_name ();
TyTy::BaseType *o = other.get_field_type ();
if (o->get_kind () == TypeKind::PARAM)
{
ParamType *op = static_cast<ParamType *> (o);
o = op->resolve ();
}
if (auto op = o->try_as<ParamType> ())
o = op->resolve ();
bool types_eq = get_field_type ()->is_equal (*o);
@ -1673,25 +1585,25 @@ ADTType::is_equal (const BaseType &other) const
if (get_kind () != other.get_kind ())
return false;
auto other2 = static_cast<const ADTType &> (other);
if (get_adt_kind () != other2.get_adt_kind ())
auto other2 = other.as<const ADTType> ();
if (get_adt_kind () != other2->get_adt_kind ())
return false;
if (number_of_variants () != other2.number_of_variants ())
if (number_of_variants () != other2->number_of_variants ())
return false;
if (has_substitutions_defined () != other2.has_substitutions_defined ())
if (has_substitutions_defined () != other2->has_substitutions_defined ())
return false;
if (has_substitutions_defined ())
{
if (get_num_substitutions () != other2.get_num_substitutions ())
if (get_num_substitutions () != other2->get_num_substitutions ())
return false;
for (size_t i = 0; i < get_num_substitutions (); i++)
{
const SubstitutionParamMapping &a = substitutions.at (i);
const SubstitutionParamMapping &b = other2.substitutions.at (i);
const SubstitutionParamMapping &b = other2->substitutions.at (i);
const ParamType *aa = a.get_param_ty ();
const ParamType *bb = b.get_param_ty ();
@ -1705,7 +1617,7 @@ ADTType::is_equal (const BaseType &other) const
for (size_t i = 0; i < number_of_variants (); i++)
{
const TyTy::VariantDef *a = get_variants ().at (i);
const TyTy::VariantDef *b = other2.get_variants ().at (i);
const TyTy::VariantDef *b = other2->get_variants ().at (i);
if (!a->is_equal (*b))
return false;
@ -1732,11 +1644,8 @@ handle_substitions (SubstitutionArgumentMappings &subst_mappings,
StructFieldType *field)
{
auto fty = field->get_field_type ();
bool is_param_ty = fty->get_kind () == TypeKind::PARAM;
if (is_param_ty)
if (auto p = fty->try_as<ParamType> ())
{
ParamType *p = static_cast<ParamType *> (fty);
SubstitutionArg arg = SubstitutionArg::error ();
bool ok = subst_mappings.get_argument_for_symbol (p, &arg);
if (ok)
@ -1781,7 +1690,7 @@ handle_substitions (SubstitutionArgumentMappings &subst_mappings,
ADTType *
ADTType::handle_substitions (SubstitutionArgumentMappings &subst_mappings)
{
ADTType *adt = static_cast<ADTType *> (clone ());
auto adt = clone ()->as<ADTType> ();
adt->set_ty_ref (mappings->get_next_hir_id ());
adt->used_arguments = subst_mappings;
@ -1905,13 +1814,13 @@ TupleType::is_equal (const BaseType &other) const
if (get_kind () != other.get_kind ())
return false;
auto other2 = static_cast<const TupleType &> (other);
if (num_fields () != other2.num_fields ())
auto other2 = other.as<const TupleType> ();
if (num_fields () != other2->num_fields ())
return false;
for (size_t i = 0; i < num_fields (); i++)
{
if (!get_field (i)->is_equal (*other2.get_field (i)))
if (!get_field (i)->is_equal (*other2->get_field (i)))
return false;
}
return true;
@ -1933,7 +1842,7 @@ TupleType::handle_substitions (SubstitutionArgumentMappings &mappings)
{
auto mappings_table = Analysis::Mappings::get ();
TupleType *tuple = static_cast<TupleType *> (clone ());
auto tuple = clone ()->as<TupleType> ();
tuple->set_ref (mappings_table->get_next_hir_id ());
tuple->set_ty_ref (mappings_table->get_next_hir_id ());
@ -3730,7 +3639,8 @@ ProjectionType::handle_substitions (
SubstitutionArgumentMappings &subst_mappings)
{
// // do we really need to substitute this?
// if (base->needs_generic_substitutions () || base->contains_type_parameters
// if (base->needs_generic_substitutions () ||
// base->contains_type_parameters
// ())
// {
// return this;