|
65 | 65 |
|
66 | 66 | @testset "Diagonal" begin |
67 | 67 | # fwd |
68 | | - @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) |
69 | | - @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) |
| 68 | + # Use size 4 to avoid Julia's 2x2/3x3 matmul fast path which |
| 69 | + # uses scalar indexing incompatible with GPU arrays |
| 70 | + @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), Diagonal([4.0, 5.0, 6.0, 7.0])) |
| 71 | + @gpu test_frule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4)) |
70 | 72 |
|
71 | 73 | # rev |
72 | | - @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), Diagonal([4.0, 5.0, 6.0])) |
73 | | - @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3)) |
| 74 | + @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), Diagonal([4.0, 5.0, 6.0, 7.0])) |
| 75 | + @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4)) |
74 | 76 |
|
75 | 77 | # Needs to not try and inplace, as `mul!` will do wrong. |
76 | 78 | # see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411 |
77 | | - @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0]), rand(3,3)) |
| 79 | + @gpu test_rrule(*, Diagonal([1.0, 2.0, 3.0, 4.0]), rand(4,4)) |
78 | 80 | end |
79 | 81 |
|
80 | 82 | @testset "$adj * Vector" for adj in (adjoint, transpose) |
|
83 | 85 | end |
84 | 86 | end |
85 | 87 |
|
| 88 | + # Use size 4 to avoid Julia's 2x2/3x3 matmul fast path which |
| 89 | + # uses scalar indexing incompatible with GPU arrays (JLArrays) |
86 | 90 | @testset "muladd: $T" for T in (Float64, ComplexF64) |
87 | | - @testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false] |
| 91 | + @testset "add $(typeof(z))" for z in [rand(), rand(T, 4), rand(T, 4, 4), false] |
88 | 92 | @testset "matrix * matrix" begin |
89 | | - A = rand(T, 3, 3) |
90 | | - B = rand(T, 3, 3) |
| 93 | + A = rand(T, 4, 4) |
| 94 | + B = rand(T, 4, 4) |
91 | 95 | @gpu test_rrule(muladd, A, B, z) |
92 | 96 | @gpu test_rrule(muladd, A', B, z) |
93 | 97 | @gpu test_rrule(muladd, A , B', z) |
94 | 98 | @gpu test_frule(muladd, A, B, z) |
95 | 99 | @gpu test_frule(muladd, A', B, z) |
96 | 100 | @gpu test_frule(muladd, A , B', z) |
97 | 101 |
|
98 | | - C = rand(T, 3, 5) |
99 | | - D = rand(T, 5, 3) |
| 102 | + C = rand(T, 4, 5) |
| 103 | + D = rand(T, 5, 4) |
100 | 104 | @gpu test_rrule(muladd, C, D, z) |
101 | 105 | @gpu test_frule(muladd, C, D, z) |
102 | 106 | end |
103 | 107 | if ndims(z) <= 1 |
104 | 108 | @testset "matrix * vector" begin |
105 | | - A, B = rand(T, 3, 3), rand(T, 3) |
| 109 | + A, B = rand(T, 4, 4), rand(T, 4) |
106 | 110 | test_rrule(muladd, A, B, z) |
107 | | - test_rrule(muladd, A, B ⊢ rand(T, 3,1), z) |
| 111 | + test_rrule(muladd, A, B ⊢ rand(T, 4,1), z) |
108 | 112 | test_frule(muladd, A, B, z) |
109 | 113 | end |
110 | 114 | @testset "adjoint * matrix" begin |
111 | | - At, B = rand(T, 3)', rand(T, 3, 3) |
| 115 | + At, B = rand(T, 4)', rand(T, 4, 4) |
112 | 116 | test_rrule(muladd, At, B, z') |
113 | | - test_rrule(muladd, At ⊢ rand(T,1,3), B, z') |
| 117 | + test_rrule(muladd, At ⊢ rand(T,1,4), B, z') |
114 | 118 | test_frule(muladd, At, B, z') |
115 | 119 | end |
116 | 120 | end |
117 | 121 | if ndims(z) == 0 |
118 | 122 | @testset "adjoint * vector" begin # like dot |
119 | | - A, B = rand(T, 3)', rand(T, 3) |
| 123 | + A, B = rand(T, 4)', rand(T, 4) |
120 | 124 | test_rrule(muladd, A, B, z) |
121 | | - test_rrule(muladd, A ⊢ rand(T,1,3), B, z') |
| 125 | + test_rrule(muladd, A ⊢ rand(T,1,4), B, z') |
122 | 126 | test_frule(muladd, A, B, z) |
123 | 127 | end |
124 | 128 | end |
125 | 129 | if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1) |
126 | 130 | @testset "vector * adjoint" begin # outer product |
127 | | - A, B = rand(T, 3), rand(T, 3)' |
| 131 | + A, B = rand(T, 4), rand(T, 4)' |
128 | 132 | test_rrule(muladd, A, B, z) |
129 | | - test_rrule(muladd, A, B ⊢ rand(T,1,3), z) |
| 133 | + test_rrule(muladd, A, B ⊢ rand(T,1,4), z) |
130 | 134 | test_frule(muladd, A, B, z) |
131 | 135 | end |
132 | 136 | end |
|
0 commit comments