From ec5eba6e1bbcf3457d5ae3ffdb3c905fa77ea0ab Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 10 Sep 2023 22:50:23 +0200 Subject: [PATCH 01/15] Also test 2-arg variants --- test/blas.jl | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index 463eb7cee5..38caa3915d 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -19,15 +19,24 @@ using Test ), Tx in (Const, Duplicated, BatchDuplicated), Ty in (Const, Duplicated, BatchDuplicated), - T in (fun == BLAS.dot ? RTs : RCs), - (sz, inc) in ((10, 1), ((2, 20), -2)) + T in (fun == BLAS.dot ? RTs : RCs) are_activities_compatible(Tret, Tx, Ty) || continue - - x = randn(T, sz) - y = randn(T, sz) atol = rtol = sqrt(eps(real(T))) - test_forward(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol) + + @testset "$fun(n, x, incx, y, incy)" begin + @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + x = randn(T, sz) + y = randn(T, sz) + test_forward(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol) + end + end + + @testset "$fun(x, y)" begin + x = randn(T, n) + y = randn(T, n) + test_forward(fun, Tret, (x, Tx), (y, Ty); atol, rtol) + end end end @@ -35,15 +44,24 @@ using Test @testset for Tret in (Const, Active), Tx in (Const, Duplicated, BatchDuplicated), Ty in (Const, Duplicated, BatchDuplicated), - T in (fun == BLAS.dot ? RTs : RCs), - (sz, inc) in ((10, 1), ((2, 20), -2)) + T in (fun == BLAS.dot ? RTs : RCs) are_activities_compatible(Tret, Tx, Ty) || continue - - x = randn(T, sz) - y = randn(T, sz) atol = rtol = sqrt(eps(real(T))) - test_reverse(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol) + + @testset "$fun(n, x, incx, y, incy)" begin + @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + x = randn(T, sz) + y = randn(T, sz) + test_reverse(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol) + end + end + + @testset "$fun(x, y)" begin + x = randn(T, n) + y = randn(T, n) + test_reverse(fun, Tret, (x, Tx), (y, Ty); atol, rtol) + end end end end From b05fb1029bd17a5a320a174e59f88d1bdbb8329c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 10 Sep 2023 22:53:03 +0200 Subject: [PATCH 02/15] Use more informative variable names --- test/blas.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index 38caa3915d..c249345539 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -4,8 +4,8 @@ using LinearAlgebra using Test @testset "BLAS rules" begin - RTs = (Float32, Float64) - RCs = (ComplexF32, ComplexF64) + BLASReals = (Float32, Float64) + BLASFloats = (ComplexF32, ComplexF64) n = 10 @testset for fun in (BLAS.dot, BLAS.dotu, BLAS.dotc) @@ -19,7 +19,7 @@ using Test ), Tx in (Const, Duplicated, BatchDuplicated), Ty in (Const, Duplicated, BatchDuplicated), - T in (fun == BLAS.dot ? RTs : RCs) + T in (fun == BLAS.dot ? BLASReals : BLASFloats) are_activities_compatible(Tret, Tx, Ty) || continue atol = rtol = sqrt(eps(real(T))) @@ -44,7 +44,7 @@ using Test @testset for Tret in (Const, Active), Tx in (Const, Duplicated, BatchDuplicated), Ty in (Const, Duplicated, BatchDuplicated), - T in (fun == BLAS.dot ? RTs : RCs) + T in (fun == BLAS.dot ? BLASReals : BLASFloats) are_activities_compatible(Tret, Tx, Ty) || continue atol = rtol = sqrt(eps(real(T))) From e27aa6af399b74a33dd0dffbb9826fb7706f289e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 11 Sep 2023 23:13:56 +0200 Subject: [PATCH 03/15] Add `BLAS.scal!` forward tests --- test/blas.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/test/blas.jl b/test/blas.jl index c249345539..388d62215b 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -8,6 +8,39 @@ using Test BLASFloats = (ComplexF32, ComplexF64) n = 10 + @testset "BLAS.scal!" begin + @testset "forward" begin + @testset for Tret in ( + Const, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + ), + Ta in (Const, Duplicated, BatchDuplicated), + Tx in (Duplicated, BatchDuplicated), + T in BLASFloats + + are_activities_compatible(Tret, Ta, Tx) || continue + atol = rtol = sqrt(eps(real(T))) + + @testset "BLAS.scal!(n, a, x, incx)" begin + @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + a = randn(T) + x = randn(T, sz) + test_forward(BLAS.scal!, Tret, n, (a, Ta), (x, Tx), inc; atol, rtol) + end + end + + @testset "BLAS.scal!(a, x)" begin + a = randn(T) + x = randn(T, n) + test_forward(BLAS.scal!, Tret, (a, Ta), (x, Tx); atol, rtol) + end + end + end + end + @testset for fun in (BLAS.dot, BLAS.dotu, BLAS.dotc) @testset "forward" begin @testset for Tret in ( From eae5f329cac17d42d3e57c498bb4e377a7c55127 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 16:23:27 +0200 Subject: [PATCH 04/15] Add MetaTesting as test dependency --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index bf2cbe5d50..db408dbd2a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,7 @@ InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MetaTesting = "9e32d19f-1e4f-477a-8631-b16c78aa0f56" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" From 2bbf22e62b212da27e047a7a24485cf7327467bb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 16:23:46 +0200 Subject: [PATCH 05/15] Fix definition of `BLASFloats` --- test/blas.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/blas.jl b/test/blas.jl index 388d62215b..44cd05de87 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -5,7 +5,8 @@ using Test @testset "BLAS rules" begin BLASReals = (Float32, Float64) - BLASFloats = (ComplexF32, ComplexF64) + BLASComplexes = (ComplexF32, ComplexF64) + BLASFloats = (BLASReals..., BLASComplexes...) n = 10 @testset "BLAS.scal!" begin From 0df7877d87a498e66dac4a43058ba800d69adac7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 16:24:04 +0200 Subject: [PATCH 06/15] Limit tests to complexes --- test/blas.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index 44cd05de87..5eababbc86 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -53,7 +53,7 @@ using Test ), Tx in (Const, Duplicated, BatchDuplicated), Ty in (Const, Duplicated, BatchDuplicated), - T in (fun == BLAS.dot ? BLASReals : BLASFloats) + T in (fun == BLAS.dot ? BLASReals : BLASComplexes) are_activities_compatible(Tret, Tx, Ty) || continue atol = rtol = sqrt(eps(real(T))) @@ -78,7 +78,7 @@ using Test @testset for Tret in (Const, Active), Tx in (Const, Duplicated, BatchDuplicated), Ty in (Const, Duplicated, BatchDuplicated), - T in (fun == BLAS.dot ? BLASReals : BLASFloats) + T in (fun == BLAS.dot ? BLASReals : BLASComplexes) are_activities_compatible(Tret, Tx, Ty) || continue atol = rtol = sqrt(eps(real(T))) From 8c32f4252a53eb1cb9a880b38d5ebbb23e668e1b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 16:24:22 +0200 Subject: [PATCH 07/15] Workaround broken tests --- test/blas.jl | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index 5eababbc86..43421b6417 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -1,6 +1,7 @@ using Enzyme using EnzymeTestUtils using LinearAlgebra +using MetaTesting using Test @testset "BLAS rules" begin @@ -29,14 +30,62 @@ using Test @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) a = randn(T) x = randn(T, sz) - test_forward(BLAS.scal!, Tret, n, (a, Ta), (x, Tx), inc; atol, rtol) + @test !fails() do + test_forward( + BLAS.scal!, Tret, n, (a, Ta), (x, Tx), inc; atol, rtol + ) + end broken = (T <: ComplexF32 && !(Ta <: Const) && !(Tx <: Const)) + end + end + @testset "BLAS.scal!(a, x)" begin + a = randn(T) + x = randn(T, n) + @test !fails() do + test_forward(BLAS.scal!, Tret, (a, Ta), (x, Tx); atol, rtol) + end broken = (T <: ComplexF32 && !(Ta <: Const) && !(Tx <: Const)) + end + end + end + + @testset "reverse" begin + @testset for Tret in (Const,), + Ta in (Const, Active), + Tx in (Duplicated, BatchDuplicated), + T in BLASFloats + + are_activities_compatible(Tret, Ta, Tx) || continue + atol = rtol = sqrt(eps(real(T))) + + if T <: Complex && Ta <: Active && Tx <: BatchDuplicated + # avoid failure that crashes Julia + @test false skip = true + continue + end + + @testset "BLAS.scal!(n, a, x, incx)" begin + @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + a = randn(T) + x = randn(T, sz) + @test !fails() do + test_reverse( + Tret, n, (a, Ta), (x, Tx), inc; atol, rtol + ) do n, a, x, inc + BLAS.scal!(n, a, x, inc) + return nothing + end + end broken = (Tx <: BatchDuplicated && sz isa Int) end end @testset "BLAS.scal!(a, x)" begin a = randn(T) x = randn(T, n) - test_forward(BLAS.scal!, Tret, (a, Ta), (x, Tx); atol, rtol) + @test !fails() do + test_reverse(Tret, (a, Ta), (x, Tx); atol, rtol) do a, x + BLAS.scal!(a, x) + return nothing + end + end broken = (Tx <: BatchDuplicated) end end end From 0cdd0abbd8f0f861e6616cf7a495d3208bfdf3eb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 16:50:42 +0200 Subject: [PATCH 08/15] Also test Const case --- test/blas.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/blas.jl b/test/blas.jl index 43421b6417..61bb701a07 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -50,7 +50,7 @@ using Test @testset "reverse" begin @testset for Tret in (Const,), Ta in (Const, Active), - Tx in (Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), T in BLASFloats are_activities_compatible(Tret, Ta, Tx) || continue From 0c1030a54d69b5220c9cc72940862a3ff84be968 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 18:44:21 +0200 Subject: [PATCH 09/15] Add tests for `BLAS.axpy!` --- test/blas.jl | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/test/blas.jl b/test/blas.jl index 61bb701a07..9585ad6d6a 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -148,4 +148,62 @@ using Test end end end + + @testset "BLAS.axpy!" begin + @testset "forward" begin + @testset for Tret in ( + Const, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + ), + Ta in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + Ty in (Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Ta, Tx, Ty) || continue + + @testset for T in BLASFloats, sz in (10, (2, 5), (3, 4, 5)) + a = randn(T) + x = randn(T, sz) + y = randn(T, sz) + atol = rtol = sqrt(eps(real(T))) + @test !fails() do + test_forward( + BLAS.axpy!, Tret, (a, Ta), (x, Tx), (y, Ty); atol, rtol + ) + end broken = (T <: ComplexF32 && !(Ta <: Const) && !(Ty <: Const)) + end + end + end + + @testset "reverse" begin + @testset for Tret in (Const,), + Ta in (Const, Active), + Tx in (Const, Duplicated, BatchDuplicated), + Ty in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Ta, Tx, Ty) || continue + + @testset for T in BLASFloats, sz in (10, (2, 5), (3, 4, 5)) + a = randn(T) + x = randn(T, sz) + y = randn(T, sz) + atol = rtol = sqrt(eps(real(T))) + @test !fails() do + test_reverse(Tret, (a, Ta), (x, Tx), (y, Ty); atol, rtol) do a, x, y + BLAS.axpy!(a, x, y) + return nothing + end + end broken = ( + T <: ComplexF32 && + xor(Tx <: BatchDuplicated, Ty <: BatchDuplicated) && + !(Ta <: Const) && + sz isa Int + ) + end + end + end + end end From fd6b5a6552127edb4f369c02c0244d9ac74d1fa9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 21:13:38 +0200 Subject: [PATCH 10/15] Add tests for `BLAS.gemv!` --- test/blas.jl | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/test/blas.jl b/test/blas.jl index 9585ad6d6a..25debea722 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -206,4 +206,95 @@ using Test end end end + + @testset "BLAS.gemv!" begin + @testset "forward" begin + @testset for Tret in ( + Const, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + ), + Talpha in (Const, Duplicated, BatchDuplicated), + TA in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + Tbeta in (Const, Duplicated, BatchDuplicated), + Ty in (Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Talpha, TA, Tx, Tbeta, Ty) || continue + + @testset for T in BLASFloats, t in ('N', 'T', 'C') + sz = (2, 3) + alpha, beta = randn(T, 2) + A = t === 'N' ? randn(T, sz...) : randn(T, reverse(sz)...) + x = randn(T, sz[2]) + y = randn(T, sz[1]) + atol = rtol = sqrt(eps(real(T))) + @test !fails() do + test_forward( + BLAS.gemv!, + Tret, + t, + (alpha, Talpha), + (A, TA), + (x, Tx), + (beta, Tbeta), + (y, Ty); + atol, + rtol, + ) + end broken = ( + T <: ComplexF32 && + !(Ty <: Const) && + !(Talpha <: Const && Tbeta <: Const) + ) + end + end + end + + @testset "reverse" begin + @testset for Tret in (Const,), + Talpha in (Const, Active), + TA in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + Tbeta in (Const, Active), + Ty in (Const, Duplicated, BatchDuplicated), + T in BLASFloats + + are_activities_compatible(Tret, Talpha, TA, Tx, Tbeta, Ty) || continue + + if T <: Complex && any(Base.Fix2(<:, BatchDuplicated), (TA, Tx, Ty)) + # avoid failure that crashes Julia + @test false skip = true + continue + end + + @testset for t in ('N', 'T', 'C') + sz = (2, 3) + alpha, beta = randn(T, 2) + A = t === 'N' ? randn(T, sz...) : randn(T, reverse(sz)...) + x = randn(T, sz[2]) + y = randn(T, sz[1]) + atol = rtol = sqrt(eps(real(T))) + @test !fails() do + test_reverse( + Tret, + t, + (alpha, Talpha), + (A, TA), + (x, Tx), + (beta, Tbeta), + (y, Ty); + atol, + rtol, + ) do args... + BLAS.gemv!(args...) + return nothing + end + end broken = any(Base.Fix2(<:, BatchDuplicated), (Tx, Ty)) + end + end + end + end end From 40bc7b8280fe4230f104d27fdef9116b4691a5c4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 21:46:00 +0200 Subject: [PATCH 11/15] Add tests for `BLAS.spmv!` --- test/blas.jl | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/test/blas.jl b/test/blas.jl index 25debea722..16697cf23c 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -297,4 +297,86 @@ using Test end end end + + @testset "BLAS.spmv!" begin + @testset "forward" begin + @testset for Tret in ( + Const, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + ), + Talpha in (Const, Duplicated, BatchDuplicated), + TAP in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + Tbeta in (Const, Duplicated, BatchDuplicated), + Ty in (Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Talpha, TAP, Tx, Tbeta, Ty) || continue + + n = 5 + m = div(n * (n + 1), 2) + + @testset for T in BLASReals, uplo in ('U', 'L') + alpha, beta = randn(T, 2) + AP = randn(T, m) + x = randn(T, n) + y = randn(T, n) + atol = rtol = sqrt(eps(real(T))) + test_forward( + BLAS.spmv!, + Tret, + uplo, + (alpha, Talpha), + (AP, TAP), + (x, Tx), + (beta, Tbeta), + (y, Ty); + atol, + rtol, + ) + end + end + end + + @testset "reverse" begin + @testset for Tret in (Const,), + Talpha in (Const, Active), + TAP in (Const, Duplicated, BatchDuplicated), + Tx in (Const, Duplicated, BatchDuplicated), + Tbeta in (Const, Active), + Ty in (Const, Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Talpha, TAP, Tx, Tbeta, Ty) || continue + + n = 5 + m = div(n * (n + 1), 2) + + @testset for T in BLASReals, uplo in ('U', 'L') + alpha, beta = randn(T, 2) + AP = randn(T, m) + x = randn(T, n) + y = randn(T, n) + atol = rtol = sqrt(eps(real(T))) + @test !fails() do + test_reverse( + Tret, + uplo, + (alpha, Talpha), + (AP, TAP), + (x, Tx), + (beta, Tbeta), + (y, Ty); + atol, + rtol, + ) do uplo, alpha, AP, x, beta, y + BLAS.spmv!(uplo, alpha, AP, x, beta, y) + return nothing + end + end broken = TAP <: BatchDuplicated + end + end + end + end end From 505df3abf0f14876de6590eb7316228c7785594b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 22:58:47 +0200 Subject: [PATCH 12/15] Add `BLAS.gemm!` tests --- test/blas.jl | 93 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/test/blas.jl b/test/blas.jl index 16697cf23c..6ed51ca243 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -379,4 +379,97 @@ using Test end end end + + @testset "BLAS.gemm!" begin + @testset "forward" begin + @testset for Tret in ( + Const, + Duplicated, + DuplicatedNoNeed, + BatchDuplicated, + BatchDuplicatedNoNeed, + ), + Talpha in (Const, Duplicated, BatchDuplicated), + TA in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated), + Tbeta in (Const, Duplicated, BatchDuplicated), + TC in (Duplicated, BatchDuplicated) + + are_activities_compatible(Tret, Talpha, TA, TB, Tbeta, TC) || continue + + @testset for T in BLASFloats, tA in ('N', 'T', 'C'), tB in ('N', 'T', 'C') + szA = (2, 3) + szB = (3, 4) + alpha, beta = randn(T, 2) + A = tA === 'N' ? randn(T, szA...) : randn(T, reverse(szA)...) + B = tB === 'N' ? randn(T, szB...) : randn(T, reverse(szB)...) + C = randn(T, (szA[1], szB[2])) + atol = rtol = sqrt(eps(real(T))) + @test !fails() do + test_forward( + BLAS.gemm!, + Tret, + tA, + tB, + (alpha, Talpha), + (A, TA), + (B, TB), + (beta, Tbeta), + (C, TC); + atol, + rtol, + ) + end broken = T <: ComplexF32 && !(Talpha <: Const && Tbeta <: Const) + end + end + end + + @testset "reverse" begin + @testset for Tret in (Const,), + Talpha in (Const, Active), + TA in (Const, Duplicated, BatchDuplicated), + TB in (Const, Duplicated, BatchDuplicated), + Tbeta in (Const, Active), + TC in (Duplicated, BatchDuplicated), + T in BLASFloats + + are_activities_compatible(Tret, Talpha, TA, TB, Tbeta, TC) || continue + + if T <: Complex && any(Base.Fix2(<:, BatchDuplicated), (TA, TB, TC)) + # avoid failure that crashes Julia + @test false skip = true + continue + end + + @testset for tA in ('N', 'T', 'C'), tB in ('N', 'T', 'C') + szA = (2, 3) + szB = (3, 4) + alpha, beta = randn(T, 2) + A = tA === 'N' ? randn(T, szA...) : randn(T, reverse(szA)...) + B = tB === 'N' ? randn(T, szB...) : randn(T, reverse(szB)...) + C = randn(T, (szA[1], szB[2])) + atol = rtol = sqrt(eps(real(T))) + @test !fails() do + test_reverse( + Tret, + tA, + tB, + (alpha, Talpha), + (A, TA), + (B, TB), + (beta, Tbeta), + (C, TC); + atol, + rtol, + ) do args... + BLAS.gemm!(args...) + return nothing + end + end broken = ( + T <: Complex && any(Base.Fix2(<:, BatchDuplicated), (TB, TC)) + ) + end + end + end + end end From b22187e9fefc0fe123a7757aadebe9b633853531 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 23:02:07 +0200 Subject: [PATCH 13/15] Simplify size definitions --- test/blas.jl | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index 6ed51ca243..0e274f9376 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -8,9 +8,9 @@ using Test BLASReals = (Float32, Float64) BLASComplexes = (ComplexF32, ComplexF64) BLASFloats = (BLASReals..., BLASComplexes...) - n = 10 @testset "BLAS.scal!" begin + n = 10 @testset "forward" begin @testset for Tret in ( Const, @@ -27,7 +27,7 @@ using Test atol = rtol = sqrt(eps(real(T))) @testset "BLAS.scal!(n, a, x, incx)" begin - @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + @testset for (sz, inc) in ((n, 1), ((2, 2n), -2)) a = randn(T) x = randn(T, sz) @test !fails() do @@ -63,7 +63,7 @@ using Test end @testset "BLAS.scal!(n, a, x, incx)" begin - @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + @testset for (sz, inc) in ((n, 1), ((2, 2n), -2)) a = randn(T) x = randn(T, sz) @test !fails() do @@ -93,6 +93,7 @@ using Test @testset for fun in (BLAS.dot, BLAS.dotu, BLAS.dotc) @testset "forward" begin + n = 10 @testset for Tret in ( Const, Duplicated, @@ -108,7 +109,7 @@ using Test atol = rtol = sqrt(eps(real(T))) @testset "$fun(n, x, incx, y, incy)" begin - @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + @testset for (sz, inc) in ((n, 1), ((2, 2n), -2)) x = randn(T, sz) y = randn(T, sz) test_forward(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol) @@ -133,7 +134,7 @@ using Test atol = rtol = sqrt(eps(real(T))) @testset "$fun(n, x, incx, y, incy)" begin - @testset for (sz, inc) in ((10, 1), ((2, 20), -2)) + @testset for (sz, inc) in ((n, 1), ((2, 2n), -2)) x = randn(T, sz) y = randn(T, sz) test_reverse(fun, Tret, n, (x, Tx), inc, (y, Ty), inc; atol, rtol) @@ -208,6 +209,7 @@ using Test end @testset "BLAS.gemv!" begin + sz = (2, 3) @testset "forward" begin @testset for Tret in ( Const, @@ -225,7 +227,6 @@ using Test are_activities_compatible(Tret, Talpha, TA, Tx, Tbeta, Ty) || continue @testset for T in BLASFloats, t in ('N', 'T', 'C') - sz = (2, 3) alpha, beta = randn(T, 2) A = t === 'N' ? randn(T, sz...) : randn(T, reverse(sz)...) x = randn(T, sz[2]) @@ -271,7 +272,6 @@ using Test end @testset for t in ('N', 'T', 'C') - sz = (2, 3) alpha, beta = randn(T, 2) A = t === 'N' ? randn(T, sz...) : randn(T, reverse(sz)...) x = randn(T, sz[2]) @@ -299,6 +299,8 @@ using Test end @testset "BLAS.spmv!" begin + n = 5 + m = div(n * (n + 1), 2) @testset "forward" begin @testset for Tret in ( Const, @@ -315,9 +317,6 @@ using Test are_activities_compatible(Tret, Talpha, TAP, Tx, Tbeta, Ty) || continue - n = 5 - m = div(n * (n + 1), 2) - @testset for T in BLASReals, uplo in ('U', 'L') alpha, beta = randn(T, 2) AP = randn(T, m) @@ -350,9 +349,6 @@ using Test are_activities_compatible(Tret, Talpha, TAP, Tx, Tbeta, Ty) || continue - n = 5 - m = div(n * (n + 1), 2) - @testset for T in BLASReals, uplo in ('U', 'L') alpha, beta = randn(T, 2) AP = randn(T, m) @@ -381,6 +377,8 @@ using Test end @testset "BLAS.gemm!" begin + szA = (2, 3) + szB = (3, 4) @testset "forward" begin @testset for Tret in ( Const, @@ -398,8 +396,6 @@ using Test are_activities_compatible(Tret, Talpha, TA, TB, Tbeta, TC) || continue @testset for T in BLASFloats, tA in ('N', 'T', 'C'), tB in ('N', 'T', 'C') - szA = (2, 3) - szB = (3, 4) alpha, beta = randn(T, 2) A = tA === 'N' ? randn(T, szA...) : randn(T, reverse(szA)...) B = tB === 'N' ? randn(T, szB...) : randn(T, reverse(szB)...) @@ -442,8 +438,6 @@ using Test end @testset for tA in ('N', 'T', 'C'), tB in ('N', 'T', 'C') - szA = (2, 3) - szB = (3, 4) alpha, beta = randn(T, 2) A = tA === 'N' ? randn(T, szA...) : randn(T, reverse(szA)...) B = tB === 'N' ? randn(T, szB...) : randn(T, reverse(szB)...) From fe8fd0fff9f16f502a652ee5db1746eceaecf024 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 23:58:09 +0200 Subject: [PATCH 14/15] Avoid anonymous functions --- test/blas.jl | 56 ++++++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index 0e274f9376..8d8e2b99fd 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -4,6 +4,8 @@ using LinearAlgebra using MetaTesting using Test +discard(_) = nothing + @testset "BLAS rules" begin BLASReals = (Float32, Float64) BLASComplexes = (ComplexF32, ComplexF64) @@ -68,11 +70,15 @@ using Test x = randn(T, sz) @test !fails() do test_reverse( - Tret, n, (a, Ta), (x, Tx), inc; atol, rtol - ) do n, a, x, inc - BLAS.scal!(n, a, x, inc) - return nothing - end + discard ∘ BLAS.scal!, + Tret, + n, + (a, Ta), + (x, Tx), + inc; + atol, + rtol, + ) end broken = (Tx <: BatchDuplicated && sz isa Int) end end @@ -81,10 +87,9 @@ using Test a = randn(T) x = randn(T, n) @test !fails() do - test_reverse(Tret, (a, Ta), (x, Tx); atol, rtol) do a, x - BLAS.scal!(a, x) - return nothing - end + test_reverse( + discard ∘ BLAS.scal!, Tret, (a, Ta), (x, Tx); atol, rtol + ) end broken = (Tx <: BatchDuplicated) end end @@ -92,8 +97,8 @@ using Test end @testset for fun in (BLAS.dot, BLAS.dotu, BLAS.dotc) + n = 10 @testset "forward" begin - n = 10 @testset for Tret in ( Const, Duplicated, @@ -193,10 +198,15 @@ using Test y = randn(T, sz) atol = rtol = sqrt(eps(real(T))) @test !fails() do - test_reverse(Tret, (a, Ta), (x, Tx), (y, Ty); atol, rtol) do a, x, y - BLAS.axpy!(a, x, y) - return nothing - end + test_reverse( + discard ∘ BLAS.axpy!, + Tret, + (a, Ta), + (x, Tx), + (y, Ty); + atol, + rtol, + ) end broken = ( T <: ComplexF32 && xor(Tx <: BatchDuplicated, Ty <: BatchDuplicated) && @@ -279,6 +289,7 @@ using Test atol = rtol = sqrt(eps(real(T))) @test !fails() do test_reverse( + discard ∘ BLAS.gemv!, Tret, t, (alpha, Talpha), @@ -288,10 +299,7 @@ using Test (y, Ty); atol, rtol, - ) do args... - BLAS.gemv!(args...) - return nothing - end + ) end broken = any(Base.Fix2(<:, BatchDuplicated), (Tx, Ty)) end end @@ -357,6 +365,7 @@ using Test atol = rtol = sqrt(eps(real(T))) @test !fails() do test_reverse( + discard ∘ BLAS.spmv!, Tret, uplo, (alpha, Talpha), @@ -366,10 +375,7 @@ using Test (y, Ty); atol, rtol, - ) do uplo, alpha, AP, x, beta, y - BLAS.spmv!(uplo, alpha, AP, x, beta, y) - return nothing - end + ) end broken = TAP <: BatchDuplicated end end @@ -445,6 +451,7 @@ using Test atol = rtol = sqrt(eps(real(T))) @test !fails() do test_reverse( + discard ∘ BLAS.gemm!, Tret, tA, tB, @@ -455,10 +462,7 @@ using Test (C, TC); atol, rtol, - ) do args... - BLAS.gemm!(args...) - return nothing - end + ) end broken = ( T <: Complex && any(Base.Fix2(<:, BatchDuplicated), (TB, TC)) ) From 821ee95202cea78a6643d61ec9d82c8cacd2bc5f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Sep 2023 23:58:32 +0200 Subject: [PATCH 15/15] Unmark tests as broken --- test/blas.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/test/blas.jl b/test/blas.jl index 8d8e2b99fd..aff56c33ae 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -207,12 +207,7 @@ discard(_) = nothing atol, rtol, ) - end broken = ( - T <: ComplexF32 && - xor(Tx <: BatchDuplicated, Ty <: BatchDuplicated) && - !(Ta <: Const) && - sz isa Int - ) + end end end end