1//
2// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
#include "compiler/translator/ValidateLimitations.h"

#include <algorithm>
#include <vector>

#include "angle_gl.h"
#include "compiler/translator/Diagnostics.h"
#include "compiler/translator/ParseContext.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
13
14namespace sh
15{
16
17namespace
18{
19
20int GetLoopSymbolId(TIntermLoop *loop)
21{
22 // Here we assume all the operations are valid, because the loop node is
23 // already validated before this call.
24 TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
25 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
26 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
27
28 return symbol->uniqueId().get();
29}
30
31// Traverses a node to check if it represents a constant index expression.
32// Definition:
33// constant-index-expressions are a superset of constant-expressions.
34// Constant-index-expressions can include loop indices as defined in
35// GLSL ES 1.0 spec, Appendix A, section 4.
36// The following are constant-index-expressions:
37// - Constant expressions
38// - Loop indices as defined in section 4
39// - Expressions composed of both of the above
40class ValidateConstIndexExpr : public TIntermTraverser
41{
42 public:
43 ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
44 : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
45 {}
46
47 // Returns true if the parsed node represents a constant index expression.
48 bool isValid() const { return mValid; }
49
50 void visitSymbol(TIntermSymbol *symbol) override
51 {
52 // Only constants and loop indices are allowed in a
53 // constant index expression.
54 if (mValid)
55 {
56 bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
57 symbol->uniqueId().get()) != mLoopSymbolIds.end();
58 mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
59 }
60 }
61
62 private:
63 bool mValid;
64 const std::vector<int> mLoopSymbolIds;
65};
66
// Traverses intermediate tree to ensure that the shader does not exceed the
// minimum functionality mandated in GLSL 1.0 spec, Appendix A.
// Violations are reported through the TDiagnostics passed to the constructor.
// The l-value tracking of the base class is used to detect loop indices that
// appear where an l-value is required inside a loop body.
class ValidateLimitationsTraverser : public TLValueTrackingTraverser
{
  public:
    ValidateLimitationsTraverser(sh::GLenum shaderType,
                                 TSymbolTable *symbolTable,
                                 TDiagnostics *diagnostics);

    // Reports loop indices used where an l-value is required.
    void visitSymbol(TIntermSymbol *node) override;
    // Validates indexing operations (EOpIndexDirect / EOpIndexIndirect).
    bool visitBinary(Visit, TIntermBinary *) override;
    // Validates the loop type and header, then traverses the body with the
    // loop index registered in mLoopSymbolIds.
    bool visitLoop(Visit, TIntermLoop *) override;

  private:
    // Forward an error at |loc| to mDiagnostics.
    void error(TSourceLoc loc, const char *reason, const char *token);
    void error(TSourceLoc loc, const char *reason, const ImmutableString &token);

    // Returns true if |symbol| is the index of a loop whose body is being traversed.
    bool isLoopIndex(TIntermSymbol *symbol);
    // Only for-loops are accepted; while and do-while are reported.
    bool validateLoopType(TIntermLoop *node);

    // Validates the init, condition and expression clauses of a for-loop.
    bool validateForLoopHeader(TIntermLoop *node);
    // If valid, return the index symbol id; Otherwise, return -1.
    int validateForLoopInit(TIntermLoop *node);
    bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
    bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);

    // Returns true if |node| is a const-qualified constant union.
    bool isConstExpr(TIntermNode *node);
    // Returns true if every symbol in |node| is const or an active loop index.
    bool isConstIndexExpr(TIntermNode *node);
    // Returns true if indexing does not exceed the minimum functionality
    // mandated in GLSL 1.0 spec, Appendix A, Section 5.
    bool validateIndexing(TIntermBinary *node);

    sh::GLenum mShaderType;
    TDiagnostics *mDiagnostics;
    // Unique ids of the indices of the loops whose bodies are currently being
    // traversed; the innermost loop's index is at the back.
    std::vector<int> mLoopSymbolIds;
};
103
104ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
105 TSymbolTable *symbolTable,
106 TDiagnostics *diagnostics)
107 : TLValueTrackingTraverser(true, false, false, symbolTable),
108 mShaderType(shaderType),
109 mDiagnostics(diagnostics)
110{
111 ASSERT(diagnostics);
112}
113
114void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
115{
116 if (isLoopIndex(node) && isLValueRequiredHere())
117 {
118 error(node->getLine(),
119 "Loop index cannot be statically assigned to within the body of the loop",
120 node->getName());
121 }
122}
123
124bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
125{
126 // Check indexing.
127 switch (node->getOp())
128 {
129 case EOpIndexDirect:
130 case EOpIndexIndirect:
131 validateIndexing(node);
132 break;
133 default:
134 break;
135 }
136 return true;
137}
138
139bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
140{
141 if (!validateLoopType(node))
142 return false;
143
144 if (!validateForLoopHeader(node))
145 return false;
146
147 TIntermNode *body = node->getBody();
148 if (body != nullptr)
149 {
150 mLoopSymbolIds.push_back(GetLoopSymbolId(node));
151 body->traverse(this);
152 mLoopSymbolIds.pop_back();
153 }
154
155 // The loop is fully processed - no need to visit children.
156 return false;
157}
158
159void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
160{
161 mDiagnostics->error(loc, reason, token);
162}
163
164void ValidateLimitationsTraverser::error(TSourceLoc loc,
165 const char *reason,
166 const ImmutableString &token)
167{
168 error(loc, reason, token.data());
169}
170
171bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
172{
173 return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) !=
174 mLoopSymbolIds.end();
175}
176
177bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
178{
179 TLoopType type = node->getType();
180 if (type == ELoopFor)
181 return true;
182
183 // Reject while and do-while loops.
184 error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
185 return false;
186}
187
188bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
189{
190 ASSERT(node->getType() == ELoopFor);
191
192 //
193 // The for statement has the form:
194 // for ( init-declaration ; condition ; expression ) statement
195 //
196 int indexSymbolId = validateForLoopInit(node);
197 if (indexSymbolId < 0)
198 return false;
199 if (!validateForLoopCond(node, indexSymbolId))
200 return false;
201 if (!validateForLoopExpr(node, indexSymbolId))
202 return false;
203
204 return true;
205}
206
207int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
208{
209 TIntermNode *init = node->getInit();
210 if (init == nullptr)
211 {
212 error(node->getLine(), "Missing init declaration", "for");
213 return -1;
214 }
215
216 //
217 // init-declaration has the form:
218 // type-specifier identifier = constant-expression
219 //
220 TIntermDeclaration *decl = init->getAsDeclarationNode();
221 if (decl == nullptr)
222 {
223 error(init->getLine(), "Invalid init declaration", "for");
224 return -1;
225 }
226 // To keep things simple do not allow declaration list.
227 TIntermSequence *declSeq = decl->getSequence();
228 if (declSeq->size() != 1)
229 {
230 error(decl->getLine(), "Invalid init declaration", "for");
231 return -1;
232 }
233 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
234 if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
235 {
236 error(decl->getLine(), "Invalid init declaration", "for");
237 return -1;
238 }
239 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
240 if (symbol == nullptr)
241 {
242 error(declInit->getLine(), "Invalid init declaration", "for");
243 return -1;
244 }
245 // The loop index has type int or float.
246 TBasicType type = symbol->getBasicType();
247 if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
248 {
249 error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
250 return -1;
251 }
252 // The loop index is initialized with constant expression.
253 if (!isConstExpr(declInit->getRight()))
254 {
255 error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
256 symbol->getName());
257 return -1;
258 }
259
260 return symbol->uniqueId().get();
261}
262
263bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
264{
265 TIntermNode *cond = node->getCondition();
266 if (cond == nullptr)
267 {
268 error(node->getLine(), "Missing condition", "for");
269 return false;
270 }
271 //
272 // condition has the form:
273 // loop_index relational_operator constant_expression
274 //
275 TIntermBinary *binOp = cond->getAsBinaryNode();
276 if (binOp == nullptr)
277 {
278 error(node->getLine(), "Invalid condition", "for");
279 return false;
280 }
281 // Loop index should be to the left of relational operator.
282 TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
283 if (symbol == nullptr)
284 {
285 error(binOp->getLine(), "Invalid condition", "for");
286 return false;
287 }
288 if (symbol->uniqueId().get() != indexSymbolId)
289 {
290 error(symbol->getLine(), "Expected loop index", symbol->getName());
291 return false;
292 }
293 // Relational operator is one of: > >= < <= == or !=.
294 switch (binOp->getOp())
295 {
296 case EOpEqual:
297 case EOpNotEqual:
298 case EOpLessThan:
299 case EOpGreaterThan:
300 case EOpLessThanEqual:
301 case EOpGreaterThanEqual:
302 break;
303 default:
304 error(binOp->getLine(), "Invalid relational operator",
305 GetOperatorString(binOp->getOp()));
306 break;
307 }
308 // Loop index must be compared with a constant.
309 if (!isConstExpr(binOp->getRight()))
310 {
311 error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
312 symbol->getName());
313 return false;
314 }
315
316 return true;
317}
318
319bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
320{
321 TIntermNode *expr = node->getExpression();
322 if (expr == nullptr)
323 {
324 error(node->getLine(), "Missing expression", "for");
325 return false;
326 }
327
328 // for expression has one of the following forms:
329 // loop_index++
330 // loop_index--
331 // loop_index += constant_expression
332 // loop_index -= constant_expression
333 // ++loop_index
334 // --loop_index
335 // The last two forms are not specified in the spec, but I am assuming
336 // its an oversight.
337 TIntermUnary *unOp = expr->getAsUnaryNode();
338 TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
339
340 TOperator op = EOpNull;
341 TIntermSymbol *symbol = nullptr;
342 if (unOp != nullptr)
343 {
344 op = unOp->getOp();
345 symbol = unOp->getOperand()->getAsSymbolNode();
346 }
347 else if (binOp != nullptr)
348 {
349 op = binOp->getOp();
350 symbol = binOp->getLeft()->getAsSymbolNode();
351 }
352
353 // The operand must be loop index.
354 if (symbol == nullptr)
355 {
356 error(expr->getLine(), "Invalid expression", "for");
357 return false;
358 }
359 if (symbol->uniqueId().get() != indexSymbolId)
360 {
361 error(symbol->getLine(), "Expected loop index", symbol->getName());
362 return false;
363 }
364
365 // The operator is one of: ++ -- += -=.
366 switch (op)
367 {
368 case EOpPostIncrement:
369 case EOpPostDecrement:
370 case EOpPreIncrement:
371 case EOpPreDecrement:
372 ASSERT((unOp != nullptr) && (binOp == nullptr));
373 break;
374 case EOpAddAssign:
375 case EOpSubAssign:
376 ASSERT((unOp == nullptr) && (binOp != nullptr));
377 break;
378 default:
379 error(expr->getLine(), "Invalid operator", GetOperatorString(op));
380 return false;
381 }
382
383 // Loop index must be incremented/decremented with a constant.
384 if (binOp != nullptr)
385 {
386 if (!isConstExpr(binOp->getRight()))
387 {
388 error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
389 symbol->getName());
390 return false;
391 }
392 }
393
394 return true;
395}
396
397bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
398{
399 ASSERT(node != nullptr);
400 return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
401}
402
403bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
404{
405 ASSERT(node != nullptr);
406
407 ValidateConstIndexExpr validate(mLoopSymbolIds);
408 node->traverse(&validate);
409 return validate.isValid();
410}
411
412bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
413{
414 ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
415
416 bool valid = true;
417 TIntermTyped *index = node->getRight();
418 // The index expession must be a constant-index-expression unless
419 // the operand is a uniform in a vertex shader.
420 TIntermTyped *operand = node->getLeft();
421 bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
422 if (!skip && !isConstIndexExpr(index))
423 {
424 error(index->getLine(), "Index expression must be constant", "[]");
425 valid = false;
426 }
427 return valid;
428}
429
430} // namespace
431
432bool ValidateLimitations(TIntermNode *root,
433 GLenum shaderType,
434 TSymbolTable *symbolTable,
435 TDiagnostics *diagnostics)
436{
437 ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics);
438 root->traverse(&validate);
439 return diagnostics->numErrors() == 0;
440}
441
442} // namespace sh
443