forked from scala-lms/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstencil.scala
More file actions
176 lines (154 loc) · 6.04 KB
/
stencil.scala
File metadata and controls
176 lines (154 loc) · 6.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
/**
Sliding Stencil
===============
Shonan Challenge 3.3 Stencil
https://groups.google.com/d/msg/stagedhpc/r5L4xGJETwE/1kOdwJo0kgAJ
Outline:
<div id="tableofcontents"></div>
*/
package scala.lms.tutorial
import scala.lms.common.ForwardTransformer
import scala.reflect.SourceContext
// Shonan Challenge 3.3 Stencil
// https://groups.google.com/d/msg/stagedhpc/r5L4xGJETwE/1kOdwJo0kgAJ
/**
 * Front-end interface for "sliding" loops over staged integer ranges.
 *
 * Offers two surface syntaxes (`n sliding f` and `for (i <- range.sliding)`),
 * both of which delegate to the abstract [[sliding]] method. How the loop is
 * actually generated is decided by whichever implementation is mixed in.
 */
trait Sliding extends Dsl {

  /**
   * Allocates a staged array of length `n` and stores `f(i)` at index `i`
   * for every index visited by `sliding(0, n)`.
   *
   * @param n length of the resulting array and end bound passed to `sliding`
   * @param f staged computation producing the element for a given index
   * @return the staged array that was filled
   */
  def infix_sliding[T:Typ](n: Rep[Int], f: Rep[Int] => Rep[T]): Rep[Array[T]] = {
    val a = NewArray[T](n)
    sliding(0,n)(i => a(i) = f(i))
    a
  }

  /**
   * Wrapper around a staged range whose `foreach` runs the body through
   * [[sliding]], which makes `for (i <- range.sliding) { ... }` work.
   *
   * This is a named class rather than an anonymous `new { def foreach ... }`
   * so that `foreach` is an ordinary method call instead of a reflective call
   * on a structural type.
   */
  class SlidingRange(r: Rep[Range]) {
    def foreach(f: Rep[Int] => Rep[Unit]): Rep[Unit] =
      sliding(r.start, r.end)(f)
  }

  /** Wraps the staged range `r` so that it can be iterated with `sliding`. */
  def infix_sliding(r: Rep[Range]): SlidingRange = new SlidingRange(r)

  /**
   * Runs the staged loop body `f` for the indices between `start` and `end`.
   * The exact iteration strategy is left to the implementing trait.
   */
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit]
}
/**
 * Reference implementation of [[Sliding]] that performs no sliding at all:
 * the body is simply run once per index of the staged range. It exists so
 * there is a baseline to compare against.
 */
trait NoSlidingExp extends DslExp with Sliding {
  /** Runs `f` for every index in `start until end` as a plain staged loop. */
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit] =
    for (idx <- start until end) f(idx)
}
/**
 * Implementation of [[Sliding]] that carries values shared between consecutive
 * loop iterations in staged variables instead of recomputing them.
 *
 * The loop body is recorded for index `i`, then re-emitted for `i+1` and `i+2`.
 * Symbols defined by the `i` body and referenced by the later bodies are the
 * "overlap". The first iteration is peeled to initialise one variable per
 * overlap symbol; the generated loop reads those variables, runs the body, and
 * writes the new values back.
 */
trait SlidingExp extends DslExp with Sliding {
  /** Forward transformer over this IR, used to re-emit recorded statements under a substitution. */
  object trans extends ForwardTransformer {
    val IR: SlidingExp.this.type = SlidingExp.this
  }

  /** Prints a diagnostic message to standard output with a "sliding log: " prefix. */
  def log(x: Any): Unit = {
    System.out.println("sliding log: "+x)
  }

  // some arithmetic rewrites to maximize sliding sharing

  /**
   * Staged integer addition with constant folding across nested `+`/`-` of
   * constants, and with negative constant addends turned into subtraction.
   */
  override def int_plus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] = ((lhs,rhs) match {
    case (Def(IntPlus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_plus(x, unit(y+z)) // (x+y)+z --> x+(y+z)
    case (Def(IntMinus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_minus(x, unit(y-z)) // (x-y)+z --> x-(y-z)
    // Int.MinValue is excluded: -Int.MinValue overflows back to Int.MinValue, so
    // int_plus and int_minus would otherwise hand the same constant to each
    // other forever and overflow the stack.
    case (x: Exp[Int], Const(z:Int)) if z < 0 && z != Int.MinValue => int_minus(x, unit(-z))
    case _ => super.int_plus(lhs,rhs)
  }).asInstanceOf[Exp[Int]]

  /**
   * Staged integer subtraction with constant folding across nested `+`/`-` of
   * constants, and with negative constant subtrahends turned into addition.
   */
  override def int_minus(lhs: Exp[Int], rhs: Exp[Int])(implicit pos: SourceContext): Exp[Int] = ((lhs,rhs) match {
    case (Def(IntMinus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_minus(x, unit(y+z)) // (x-y)-z --> x-(y+z)
    case (Def(IntPlus(x:Exp[Int],Const(y:Int))), Const(z:Int)) => int_plus(x, unit(y-z)) // (x+y)-z --> x+(y-z)
    // Same Int.MinValue exclusion as in int_plus, to keep the two rewrites from
    // calling each other without end.
    case (x: Exp[Int], Const(z:Int)) if z < 0 && z != Int.MinValue => int_plus(x, unit(-z))
    case _ => super.int_minus(lhs,rhs)
  }).asInstanceOf[Exp[Int]]

  /**
   * Generates the loop over `start until end`, sharing overlapping values
   * between iteration `j-1` and iteration `j` through staged variables.
   *
   * @param start first index; its iteration is peeled and emitted before the loop
   * @param end   end bound of the index range
   * @param f     staged loop body
   */
  def sliding(start: Rep[Int], end: Rep[Int])(f: Rep[Int] => Rep[Unit]): Rep[Unit] = {
    // symbolic loop index used while recording the body
    val i = fresh[Int]
    // remember the current context so it can be restored after each recording step
    val save = context

    // evaluate loop contents f(i)
    val (r0,stms0) = reifySubGraph(f(i))

    // The recordings for i+1 and i+2 are nested: each one first reflects the
    // statements recorded so far, then re-emits stms0 with the index shifted.
    val (((r1,stms1,subst1),(r2,stms2,subst2)), _) = reifySubGraph {

      reflectSubGraph(stms0)
      context = save

      // evaluate loop contents f(i+1)
      val ((r1,subst1),stms1) = reifySubGraph(trans.withSubstScope(i -> (i+1)) {
        stms0.foreach(s=>trans.traverseStm(s))
        (trans(r0), trans.subst)
      })

      val ((r2,stms2,subst2), _) = reifySubGraph {

        reflectSubGraph(stms1)
        context = save

        // evaluate loop contents f(i+2)
        val ((r2,subst2),stms2) = reifySubGraph(trans.withSubstScope(i -> (i+2)) {
          stms0.foreach(s=>trans.traverseStm(s))
          (trans(r0), trans.subst)
        })
        (r2,stms2,subst2)
      }
      ((r1,stms1,subst1), (r2,stms2,subst2))
    }
    context = save

    // symbols defined by the f(i) body
    val defs = stms0.flatMap(_.lhs)

    // find overlap syms: defined by f(i), used by f(i+1) and f(i+2)
    val overlap01 = stms1.flatMap { case TP(s,d) => syms(d) filter (defs contains _) }.distinct
    val overlap02 = stms2.flatMap { case TP(s,d) => syms(d) filter (defs contains _) }.distinct

    // NOTE(review): the message says this overlap is ignored, yet overlap02
    // symbols are still merged into overlap0 below -- confirm the intent.
    if (overlap02.nonEmpty)
      log("Overlap beyond a single loop iteration is ignored, since not yet implemented.")

    val overlap0 = (overlap01++overlap02).distinct
    // the same symbols as they appear in the f(i+1) body
    val overlap1 = overlap0 map subst1

    // build a variable for each overlap sym.
    // init variables by peeling first loop iteration.
    if (end > start) {
      val (rX,substX) = trans.withSubstScope(i -> start) {
        stms0.foreach(s=>trans.traverseStm(s))
        (trans(r0), trans.subst)
      }
      // assumes every overlap symbol carries at least one source position (pos.head)
      val vars = overlap0 map (x => var_new(substX(x))(x.tp,x.pos.head))

      // now generate the loop
      for (j <- (start + unit(1)) until end) {

        // read the overlap variables
        val reads = (overlap0 zip vars) map (p => (p._1, readVar(p._2)(p._1.tp,p._1.pos.head)))

        // emit the transformed loop body
        // (stms1 is the body recorded for i+1, so i is bound to j-1 here)
        val (r,substY1) = trans.withSubstScope((reads:+(i->(j-unit(1)))): _*) {
          stms1.foreach(s=>trans.traverseStm(s))
          (trans(r1), trans.subst)
        }

        // write the new values to the overlap vars
        val writes = (overlap1 zip vars) map (p => (p._1, var_assign(p._2, substY1(p._1))(p._1.tp,p._1.pos.head)))
      }
    }
  }
}
/**
 * Small warm-up program: element `i` of the result is the sum of an affine
 * function evaluated at `i` and at `i+1`, so neighbouring elements share one
 * of their two terms.
 */
trait SlidingWarmup extends Sliding {
  /** Builds an array of length `n` whose element `i` is `(2*i+3) + (2*(i+1)+3)`. */
  def snippet(n: Rep[Int]): Rep[Array[Int]] = {
    def affine(k: Rep[Int]) = 2 * k + 3

    n sliding { idx =>
      val here = affine(idx)
      val next = affine(idx + 1)
      here + next
    }
  }
}
/** Driver for the warm-up snippet; a concrete `sliding` implementation is mixed in at instantiation. */
abstract class SlidingWarmupDriver
  extends DslDriver[Int, Array[Int]]
  with SlidingWarmup
/** Passes the code generated for the warm-up snippet to `check`, once per sliding implementation. */
class SlidingWarmupTest extends TutorialFunSuite {
  // presumably the name prefix TutorialFunSuite uses for this suite's check files -- confirm
  val under = "sliding"

  test("warmup without sliding") {
    val baseline = new SlidingWarmupDriver with NoSlidingExp
    val generated = baseline.code
    check("0", generated)
  }

  test("warmup with sliding") {
    val optimized = new SlidingWarmupDriver with SlidingExp
    val generated = optimized.code
    check("1", generated)
  }
}
/**
 * The stencil computation: each output element is built from products and
 * differences of neighbouring input elements, so adjacent loop iterations
 * share intermediate values.
 */
trait Stencil extends Sliding {
  /**
   * Computes the stencil over `v`.
   *
   * @param v staged input array
   * @return a new staged array of the same length; only the indices in
   *         `2 until n-2` are assigned by the loop
   */
  def snippet(v: Rep[Array[Double]]): Rep[Array[Double]] = {
    val n = v.length
    val result = NewArray[Double](n)

    // element of the input at index k
    def a(k: Rep[Int]) = v(k)
    // product of two neighbouring inputs
    def w1(k: Rep[Int]) = a(k) * a(k + 1)
    // first-level combination
    def wm(k: Rep[Int]) = a(k) - w1(k) + w1(k - 1)
    // product of two neighbouring first-level values
    def w2(k: Rep[Int]) = wm(k) * wm(k + 1)
    // second-level combination: the value written to the output
    def b(k: Rep[Int]) = wm(k) - w2(k) + w2(k - 1)

    (2 until n-2).sliding.foreach { i =>
      result(i) = b(i)
    }
    result
  }
}
/** Driver for the stencil snippet; a concrete `sliding` implementation is mixed in at instantiation. */
abstract class StencilDriver
  extends DslDriver[Array[Double], Array[Double]]
  with Stencil
/** Passes the code generated for the stencil snippet to `check`, once per sliding implementation. */
class StencilTest extends TutorialFunSuite {
  // presumably the name prefix TutorialFunSuite uses for this suite's check files -- confirm
  val under = "stencil"

  test("stencil without sliding") {
    val baseline = new StencilDriver with NoSlidingExp
    val generated = baseline.code
    check("0", generated)
  }

  test("stencil with sliding") {
    val optimized = new StencilDriver with SlidingExp
    val generated = optimized.code
    check("1", generated)
  }
}