1 | // |
2 | // Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved. |
3 | // Use of this source code is governed by a BSD-style license that can be |
4 | // found in the LICENSE file. |
5 | // |
6 | |
7 | #include "compiler/translator/ValidateLimitations.h" |
8 | |
9 | #include "angle_gl.h" |
10 | #include "compiler/translator/Diagnostics.h" |
11 | #include "compiler/translator/ParseContext.h" |
12 | #include "compiler/translator/tree_util/IntermTraverse.h" |
13 | |
14 | namespace sh |
15 | { |
16 | |
17 | namespace |
18 | { |
19 | |
20 | int GetLoopSymbolId(TIntermLoop *loop) |
21 | { |
22 | // Here we assume all the operations are valid, because the loop node is |
23 | // already validated before this call. |
24 | TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence(); |
25 | TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode(); |
26 | TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode(); |
27 | |
28 | return symbol->uniqueId().get(); |
29 | } |
30 | |
31 | // Traverses a node to check if it represents a constant index expression. |
32 | // Definition: |
33 | // constant-index-expressions are a superset of constant-expressions. |
34 | // Constant-index-expressions can include loop indices as defined in |
35 | // GLSL ES 1.0 spec, Appendix A, section 4. |
36 | // The following are constant-index-expressions: |
37 | // - Constant expressions |
38 | // - Loop indices as defined in section 4 |
39 | // - Expressions composed of both of the above |
40 | class ValidateConstIndexExpr : public TIntermTraverser |
41 | { |
42 | public: |
43 | ValidateConstIndexExpr(const std::vector<int> &loopSymbols) |
44 | : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols) |
45 | {} |
46 | |
47 | // Returns true if the parsed node represents a constant index expression. |
48 | bool isValid() const { return mValid; } |
49 | |
50 | void visitSymbol(TIntermSymbol *symbol) override |
51 | { |
52 | // Only constants and loop indices are allowed in a |
53 | // constant index expression. |
54 | if (mValid) |
55 | { |
56 | bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), |
57 | symbol->uniqueId().get()) != mLoopSymbolIds.end(); |
58 | mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol; |
59 | } |
60 | } |
61 | |
62 | private: |
63 | bool mValid; |
64 | const std::vector<int> mLoopSymbolIds; |
65 | }; |
66 | |
67 | // Traverses intermediate tree to ensure that the shader does not exceed the |
68 | // minimum functionality mandated in GLSL 1.0 spec, Appendix A. |
69 | class ValidateLimitationsTraverser : public TLValueTrackingTraverser |
70 | { |
71 | public: |
72 | ValidateLimitationsTraverser(sh::GLenum shaderType, |
73 | TSymbolTable *symbolTable, |
74 | TDiagnostics *diagnostics); |
75 | |
76 | void visitSymbol(TIntermSymbol *node) override; |
77 | bool visitBinary(Visit, TIntermBinary *) override; |
78 | bool visitLoop(Visit, TIntermLoop *) override; |
79 | |
80 | private: |
81 | void error(TSourceLoc loc, const char *reason, const char *token); |
82 | void error(TSourceLoc loc, const char *reason, const ImmutableString &token); |
83 | |
84 | bool isLoopIndex(TIntermSymbol *symbol); |
85 | bool validateLoopType(TIntermLoop *node); |
86 | |
87 | bool validateForLoopHeader(TIntermLoop *node); |
88 | // If valid, return the index symbol id; Otherwise, return -1. |
89 | int validateForLoopInit(TIntermLoop *node); |
90 | bool validateForLoopCond(TIntermLoop *node, int indexSymbolId); |
91 | bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId); |
92 | |
93 | // Returns true if indexing does not exceed the minimum functionality |
94 | // mandated in GLSL 1.0 spec, Appendix A, Section 5. |
95 | bool isConstExpr(TIntermNode *node); |
96 | bool isConstIndexExpr(TIntermNode *node); |
97 | bool validateIndexing(TIntermBinary *node); |
98 | |
99 | sh::GLenum mShaderType; |
100 | TDiagnostics *mDiagnostics; |
101 | std::vector<int> mLoopSymbolIds; |
102 | }; |
103 | |
104 | ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType, |
105 | TSymbolTable *symbolTable, |
106 | TDiagnostics *diagnostics) |
107 | : TLValueTrackingTraverser(true, false, false, symbolTable), |
108 | mShaderType(shaderType), |
109 | mDiagnostics(diagnostics) |
110 | { |
111 | ASSERT(diagnostics); |
112 | } |
113 | |
114 | void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node) |
115 | { |
116 | if (isLoopIndex(node) && isLValueRequiredHere()) |
117 | { |
118 | error(node->getLine(), |
119 | "Loop index cannot be statically assigned to within the body of the loop" , |
120 | node->getName()); |
121 | } |
122 | } |
123 | |
124 | bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node) |
125 | { |
126 | // Check indexing. |
127 | switch (node->getOp()) |
128 | { |
129 | case EOpIndexDirect: |
130 | case EOpIndexIndirect: |
131 | validateIndexing(node); |
132 | break; |
133 | default: |
134 | break; |
135 | } |
136 | return true; |
137 | } |
138 | |
139 | bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node) |
140 | { |
141 | if (!validateLoopType(node)) |
142 | return false; |
143 | |
144 | if (!validateForLoopHeader(node)) |
145 | return false; |
146 | |
147 | TIntermNode *body = node->getBody(); |
148 | if (body != nullptr) |
149 | { |
150 | mLoopSymbolIds.push_back(GetLoopSymbolId(node)); |
151 | body->traverse(this); |
152 | mLoopSymbolIds.pop_back(); |
153 | } |
154 | |
155 | // The loop is fully processed - no need to visit children. |
156 | return false; |
157 | } |
158 | |
159 | void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token) |
160 | { |
161 | mDiagnostics->error(loc, reason, token); |
162 | } |
163 | |
164 | void ValidateLimitationsTraverser::error(TSourceLoc loc, |
165 | const char *reason, |
166 | const ImmutableString &token) |
167 | { |
168 | error(loc, reason, token.data()); |
169 | } |
170 | |
171 | bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol) |
172 | { |
173 | return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) != |
174 | mLoopSymbolIds.end(); |
175 | } |
176 | |
177 | bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node) |
178 | { |
179 | TLoopType type = node->getType(); |
180 | if (type == ELoopFor) |
181 | return true; |
182 | |
183 | // Reject while and do-while loops. |
184 | error(node->getLine(), "This type of loop is not allowed" , type == ELoopWhile ? "while" : "do" ); |
185 | return false; |
186 | } |
187 | |
188 | bool ValidateLimitationsTraverser::(TIntermLoop *node) |
189 | { |
190 | ASSERT(node->getType() == ELoopFor); |
191 | |
192 | // |
193 | // The for statement has the form: |
194 | // for ( init-declaration ; condition ; expression ) statement |
195 | // |
196 | int indexSymbolId = validateForLoopInit(node); |
197 | if (indexSymbolId < 0) |
198 | return false; |
199 | if (!validateForLoopCond(node, indexSymbolId)) |
200 | return false; |
201 | if (!validateForLoopExpr(node, indexSymbolId)) |
202 | return false; |
203 | |
204 | return true; |
205 | } |
206 | |
207 | int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node) |
208 | { |
209 | TIntermNode *init = node->getInit(); |
210 | if (init == nullptr) |
211 | { |
212 | error(node->getLine(), "Missing init declaration" , "for" ); |
213 | return -1; |
214 | } |
215 | |
216 | // |
217 | // init-declaration has the form: |
218 | // type-specifier identifier = constant-expression |
219 | // |
220 | TIntermDeclaration *decl = init->getAsDeclarationNode(); |
221 | if (decl == nullptr) |
222 | { |
223 | error(init->getLine(), "Invalid init declaration" , "for" ); |
224 | return -1; |
225 | } |
226 | // To keep things simple do not allow declaration list. |
227 | TIntermSequence *declSeq = decl->getSequence(); |
228 | if (declSeq->size() != 1) |
229 | { |
230 | error(decl->getLine(), "Invalid init declaration" , "for" ); |
231 | return -1; |
232 | } |
233 | TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode(); |
234 | if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize)) |
235 | { |
236 | error(decl->getLine(), "Invalid init declaration" , "for" ); |
237 | return -1; |
238 | } |
239 | TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode(); |
240 | if (symbol == nullptr) |
241 | { |
242 | error(declInit->getLine(), "Invalid init declaration" , "for" ); |
243 | return -1; |
244 | } |
245 | // The loop index has type int or float. |
246 | TBasicType type = symbol->getBasicType(); |
247 | if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat)) |
248 | { |
249 | error(symbol->getLine(), "Invalid type for loop index" , getBasicString(type)); |
250 | return -1; |
251 | } |
252 | // The loop index is initialized with constant expression. |
253 | if (!isConstExpr(declInit->getRight())) |
254 | { |
255 | error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression" , |
256 | symbol->getName()); |
257 | return -1; |
258 | } |
259 | |
260 | return symbol->uniqueId().get(); |
261 | } |
262 | |
263 | bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId) |
264 | { |
265 | TIntermNode *cond = node->getCondition(); |
266 | if (cond == nullptr) |
267 | { |
268 | error(node->getLine(), "Missing condition" , "for" ); |
269 | return false; |
270 | } |
271 | // |
272 | // condition has the form: |
273 | // loop_index relational_operator constant_expression |
274 | // |
275 | TIntermBinary *binOp = cond->getAsBinaryNode(); |
276 | if (binOp == nullptr) |
277 | { |
278 | error(node->getLine(), "Invalid condition" , "for" ); |
279 | return false; |
280 | } |
281 | // Loop index should be to the left of relational operator. |
282 | TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode(); |
283 | if (symbol == nullptr) |
284 | { |
285 | error(binOp->getLine(), "Invalid condition" , "for" ); |
286 | return false; |
287 | } |
288 | if (symbol->uniqueId().get() != indexSymbolId) |
289 | { |
290 | error(symbol->getLine(), "Expected loop index" , symbol->getName()); |
291 | return false; |
292 | } |
293 | // Relational operator is one of: > >= < <= == or !=. |
294 | switch (binOp->getOp()) |
295 | { |
296 | case EOpEqual: |
297 | case EOpNotEqual: |
298 | case EOpLessThan: |
299 | case EOpGreaterThan: |
300 | case EOpLessThanEqual: |
301 | case EOpGreaterThanEqual: |
302 | break; |
303 | default: |
304 | error(binOp->getLine(), "Invalid relational operator" , |
305 | GetOperatorString(binOp->getOp())); |
306 | break; |
307 | } |
308 | // Loop index must be compared with a constant. |
309 | if (!isConstExpr(binOp->getRight())) |
310 | { |
311 | error(binOp->getLine(), "Loop index cannot be compared with non-constant expression" , |
312 | symbol->getName()); |
313 | return false; |
314 | } |
315 | |
316 | return true; |
317 | } |
318 | |
319 | bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId) |
320 | { |
321 | TIntermNode *expr = node->getExpression(); |
322 | if (expr == nullptr) |
323 | { |
324 | error(node->getLine(), "Missing expression" , "for" ); |
325 | return false; |
326 | } |
327 | |
328 | // for expression has one of the following forms: |
329 | // loop_index++ |
330 | // loop_index-- |
331 | // loop_index += constant_expression |
332 | // loop_index -= constant_expression |
333 | // ++loop_index |
334 | // --loop_index |
335 | // The last two forms are not specified in the spec, but I am assuming |
336 | // its an oversight. |
337 | TIntermUnary *unOp = expr->getAsUnaryNode(); |
338 | TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode(); |
339 | |
340 | TOperator op = EOpNull; |
341 | TIntermSymbol *symbol = nullptr; |
342 | if (unOp != nullptr) |
343 | { |
344 | op = unOp->getOp(); |
345 | symbol = unOp->getOperand()->getAsSymbolNode(); |
346 | } |
347 | else if (binOp != nullptr) |
348 | { |
349 | op = binOp->getOp(); |
350 | symbol = binOp->getLeft()->getAsSymbolNode(); |
351 | } |
352 | |
353 | // The operand must be loop index. |
354 | if (symbol == nullptr) |
355 | { |
356 | error(expr->getLine(), "Invalid expression" , "for" ); |
357 | return false; |
358 | } |
359 | if (symbol->uniqueId().get() != indexSymbolId) |
360 | { |
361 | error(symbol->getLine(), "Expected loop index" , symbol->getName()); |
362 | return false; |
363 | } |
364 | |
365 | // The operator is one of: ++ -- += -=. |
366 | switch (op) |
367 | { |
368 | case EOpPostIncrement: |
369 | case EOpPostDecrement: |
370 | case EOpPreIncrement: |
371 | case EOpPreDecrement: |
372 | ASSERT((unOp != nullptr) && (binOp == nullptr)); |
373 | break; |
374 | case EOpAddAssign: |
375 | case EOpSubAssign: |
376 | ASSERT((unOp == nullptr) && (binOp != nullptr)); |
377 | break; |
378 | default: |
379 | error(expr->getLine(), "Invalid operator" , GetOperatorString(op)); |
380 | return false; |
381 | } |
382 | |
383 | // Loop index must be incremented/decremented with a constant. |
384 | if (binOp != nullptr) |
385 | { |
386 | if (!isConstExpr(binOp->getRight())) |
387 | { |
388 | error(binOp->getLine(), "Loop index cannot be modified by non-constant expression" , |
389 | symbol->getName()); |
390 | return false; |
391 | } |
392 | } |
393 | |
394 | return true; |
395 | } |
396 | |
397 | bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node) |
398 | { |
399 | ASSERT(node != nullptr); |
400 | return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst; |
401 | } |
402 | |
403 | bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node) |
404 | { |
405 | ASSERT(node != nullptr); |
406 | |
407 | ValidateConstIndexExpr validate(mLoopSymbolIds); |
408 | node->traverse(&validate); |
409 | return validate.isValid(); |
410 | } |
411 | |
412 | bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node) |
413 | { |
414 | ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect)); |
415 | |
416 | bool valid = true; |
417 | TIntermTyped *index = node->getRight(); |
418 | // The index expession must be a constant-index-expression unless |
419 | // the operand is a uniform in a vertex shader. |
420 | TIntermTyped *operand = node->getLeft(); |
421 | bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform); |
422 | if (!skip && !isConstIndexExpr(index)) |
423 | { |
424 | error(index->getLine(), "Index expression must be constant" , "[]" ); |
425 | valid = false; |
426 | } |
427 | return valid; |
428 | } |
429 | |
430 | } // namespace |
431 | |
432 | bool ValidateLimitations(TIntermNode *root, |
433 | GLenum shaderType, |
434 | TSymbolTable *symbolTable, |
435 | TDiagnostics *diagnostics) |
436 | { |
437 | ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics); |
438 | root->traverse(&validate); |
439 | return diagnostics->numErrors() == 0; |
440 | } |
441 | |
442 | } // namespace sh |
443 | |