1 | // |
2 | // Copyright 2015 The ANGLE Project Authors. All rights reserved. |
3 | // Use of this source code is governed by a BSD-style license that can be |
4 | // found in the LICENSE file. |
5 | // |
6 | // Matrix: |
7 | // Utility class implementing various matrix operations. |
//   Supports matrices with 1 to 4 rows/columns; determinant() and inverse() require a square
//   matrix of size 2, 3 or 4.
9 | // |
10 | // TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this |
11 | // implementation. |
12 | // TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util. |
13 | |
14 | #ifndef COMMON_MATRIX_UTILS_H_ |
15 | #define COMMON_MATRIX_UTILS_H_ |
16 | |
17 | #include <vector> |
18 | |
19 | #include "common/debug.h" |
20 | #include "common/mathutil.h" |
21 | #include "common/vector_utils.h" |
22 | |
23 | namespace angle |
24 | { |
25 | |
26 | template <typename T> |
27 | class Matrix |
28 | { |
29 | public: |
30 | Matrix(const std::vector<T> &elements, const unsigned int numRows, const unsigned int numCols) |
31 | : mElements(elements), mRows(numRows), mCols(numCols) |
32 | { |
33 | ASSERT(rows() >= 1 && rows() <= 4); |
34 | ASSERT(columns() >= 1 && columns() <= 4); |
35 | } |
36 | |
37 | Matrix(const std::vector<T> &elements, const unsigned int size) |
38 | : mElements(elements), mRows(size), mCols(size) |
39 | { |
40 | ASSERT(rows() >= 1 && rows() <= 4); |
41 | ASSERT(columns() >= 1 && columns() <= 4); |
42 | } |
43 | |
44 | Matrix(const T *elements, const unsigned int size) : mRows(size), mCols(size) |
45 | { |
46 | ASSERT(rows() >= 1 && rows() <= 4); |
47 | ASSERT(columns() >= 1 && columns() <= 4); |
48 | for (size_t i = 0; i < size * size; i++) |
49 | mElements.push_back(elements[i]); |
50 | } |
51 | |
52 | const T &operator()(const unsigned int rowIndex, const unsigned int columnIndex) const |
53 | { |
54 | ASSERT(rowIndex < mRows); |
55 | ASSERT(columnIndex < mCols); |
56 | return mElements[rowIndex * columns() + columnIndex]; |
57 | } |
58 | |
59 | T &operator()(const unsigned int rowIndex, const unsigned int columnIndex) |
60 | { |
61 | ASSERT(rowIndex < mRows); |
62 | ASSERT(columnIndex < mCols); |
63 | return mElements[rowIndex * columns() + columnIndex]; |
64 | } |
65 | |
66 | const T &at(const unsigned int rowIndex, const unsigned int columnIndex) const |
67 | { |
68 | ASSERT(rowIndex < mRows); |
69 | ASSERT(columnIndex < mCols); |
70 | return operator()(rowIndex, columnIndex); |
71 | } |
72 | |
73 | Matrix<T> operator*(const Matrix<T> &m) |
74 | { |
75 | ASSERT(columns() == m.rows()); |
76 | |
77 | unsigned int resultRows = rows(); |
78 | unsigned int resultCols = m.columns(); |
79 | Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols); |
80 | for (unsigned int i = 0; i < resultRows; i++) |
81 | { |
82 | for (unsigned int j = 0; j < resultCols; j++) |
83 | { |
84 | T tmp = 0.0f; |
85 | for (unsigned int k = 0; k < columns(); k++) |
86 | tmp += at(i, k) * m(k, j); |
87 | result(i, j) = tmp; |
88 | } |
89 | } |
90 | |
91 | return result; |
92 | } |
93 | |
94 | void operator*=(const Matrix<T> &m) |
95 | { |
96 | ASSERT(columns() == m.rows()); |
97 | Matrix<T> res = (*this) * m; |
98 | size_t numElts = res.elements().size(); |
99 | mElements.resize(numElts); |
100 | memcpy(mElements.data(), res.data(), numElts * sizeof(float)); |
101 | } |
102 | |
103 | bool operator==(const Matrix<T> &m) const |
104 | { |
105 | ASSERT(columns() == m.columns()); |
106 | ASSERT(rows() == m.rows()); |
107 | return mElements == m.elements(); |
108 | } |
109 | |
110 | bool operator!=(const Matrix<T> &m) const { return !(mElements == m.elements()); } |
111 | |
112 | bool nearlyEqual(T epsilon, const Matrix<T> &m) const |
113 | { |
114 | ASSERT(columns() == m.columns()); |
115 | ASSERT(rows() == m.rows()); |
116 | const auto &otherElts = m.elements(); |
117 | for (size_t i = 0; i < otherElts.size(); i++) |
118 | { |
119 | if ((mElements[i] - otherElts[i] > epsilon) && (otherElts[i] - mElements[i] > epsilon)) |
120 | return false; |
121 | } |
122 | return true; |
123 | } |
124 | |
125 | unsigned int size() const |
126 | { |
127 | ASSERT(rows() == columns()); |
128 | return rows(); |
129 | } |
130 | |
131 | unsigned int rows() const { return mRows; } |
132 | |
133 | unsigned int columns() const { return mCols; } |
134 | |
135 | std::vector<T> elements() const { return mElements; } |
136 | T *data() { return mElements.data(); } |
137 | |
138 | Matrix<T> compMult(const Matrix<T> &mat1) const |
139 | { |
140 | Matrix result(std::vector<T>(mElements.size()), rows(), columns()); |
141 | for (unsigned int i = 0; i < rows(); i++) |
142 | { |
143 | for (unsigned int j = 0; j < columns(); j++) |
144 | { |
145 | T lhs = at(i, j); |
146 | T rhs = mat1(i, j); |
147 | result(i, j) = rhs * lhs; |
148 | } |
149 | } |
150 | |
151 | return result; |
152 | } |
153 | |
154 | Matrix<T> outerProduct(const Matrix<T> &mat1) const |
155 | { |
156 | unsigned int cols = mat1.columns(); |
157 | Matrix result(std::vector<T>(rows() * cols), rows(), cols); |
158 | for (unsigned int i = 0; i < rows(); i++) |
159 | for (unsigned int j = 0; j < cols; j++) |
160 | result(i, j) = at(i, 0) * mat1(0, j); |
161 | |
162 | return result; |
163 | } |
164 | |
165 | Matrix<T> transpose() const |
166 | { |
167 | Matrix result(std::vector<T>(mElements.size()), columns(), rows()); |
168 | for (unsigned int i = 0; i < columns(); i++) |
169 | for (unsigned int j = 0; j < rows(); j++) |
170 | result(i, j) = at(j, i); |
171 | |
172 | return result; |
173 | } |
174 | |
175 | T determinant() const |
176 | { |
177 | ASSERT(rows() == columns()); |
178 | |
179 | switch (size()) |
180 | { |
181 | case 2: |
182 | return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0); |
183 | |
184 | case 3: |
185 | return at(0, 0) * at(1, 1) * at(2, 2) + at(0, 1) * at(1, 2) * at(2, 0) + |
186 | at(0, 2) * at(1, 0) * at(2, 1) - at(0, 2) * at(1, 1) * at(2, 0) - |
187 | at(0, 1) * at(1, 0) * at(2, 2) - at(0, 0) * at(1, 2) * at(2, 1); |
188 | |
189 | case 4: |
190 | { |
191 | const float minorMatrices[4][3 * 3] = {{ |
192 | at(1, 1), |
193 | at(2, 1), |
194 | at(3, 1), |
195 | at(1, 2), |
196 | at(2, 2), |
197 | at(3, 2), |
198 | at(1, 3), |
199 | at(2, 3), |
200 | at(3, 3), |
201 | }, |
202 | { |
203 | at(1, 0), |
204 | at(2, 0), |
205 | at(3, 0), |
206 | at(1, 2), |
207 | at(2, 2), |
208 | at(3, 2), |
209 | at(1, 3), |
210 | at(2, 3), |
211 | at(3, 3), |
212 | }, |
213 | { |
214 | at(1, 0), |
215 | at(2, 0), |
216 | at(3, 0), |
217 | at(1, 1), |
218 | at(2, 1), |
219 | at(3, 1), |
220 | at(1, 3), |
221 | at(2, 3), |
222 | at(3, 3), |
223 | }, |
224 | { |
225 | at(1, 0), |
226 | at(2, 0), |
227 | at(3, 0), |
228 | at(1, 1), |
229 | at(2, 1), |
230 | at(3, 1), |
231 | at(1, 2), |
232 | at(2, 2), |
233 | at(3, 2), |
234 | }}; |
235 | return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() - |
236 | at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() + |
237 | at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() - |
238 | at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant(); |
239 | } |
240 | |
241 | default: |
242 | UNREACHABLE(); |
243 | break; |
244 | } |
245 | |
246 | return T(); |
247 | } |
248 | |
249 | Matrix<T> inverse() const |
250 | { |
251 | ASSERT(rows() == columns()); |
252 | |
253 | Matrix<T> cof(std::vector<T>(mElements.size()), rows(), columns()); |
254 | switch (size()) |
255 | { |
256 | case 2: |
257 | cof(0, 0) = at(1, 1); |
258 | cof(0, 1) = -at(1, 0); |
259 | cof(1, 0) = -at(0, 1); |
260 | cof(1, 1) = at(0, 0); |
261 | break; |
262 | |
263 | case 3: |
264 | cof(0, 0) = at(1, 1) * at(2, 2) - at(2, 1) * at(1, 2); |
265 | cof(0, 1) = -(at(1, 0) * at(2, 2) - at(2, 0) * at(1, 2)); |
266 | cof(0, 2) = at(1, 0) * at(2, 1) - at(2, 0) * at(1, 1); |
267 | cof(1, 0) = -(at(0, 1) * at(2, 2) - at(2, 1) * at(0, 2)); |
268 | cof(1, 1) = at(0, 0) * at(2, 2) - at(2, 0) * at(0, 2); |
269 | cof(1, 2) = -(at(0, 0) * at(2, 1) - at(2, 0) * at(0, 1)); |
270 | cof(2, 0) = at(0, 1) * at(1, 2) - at(1, 1) * at(0, 2); |
271 | cof(2, 1) = -(at(0, 0) * at(1, 2) - at(1, 0) * at(0, 2)); |
272 | cof(2, 2) = at(0, 0) * at(1, 1) - at(1, 0) * at(0, 1); |
273 | break; |
274 | |
275 | case 4: |
276 | cof(0, 0) = at(1, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(1, 3) + |
277 | at(3, 1) * at(1, 2) * at(2, 3) - at(1, 1) * at(3, 2) * at(2, 3) - |
278 | at(2, 1) * at(1, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(1, 3); |
279 | cof(0, 1) = -(at(1, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(1, 3) + |
280 | at(3, 0) * at(1, 2) * at(2, 3) - at(1, 0) * at(3, 2) * at(2, 3) - |
281 | at(2, 0) * at(1, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(1, 3)); |
282 | cof(0, 2) = at(1, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(1, 3) + |
283 | at(3, 0) * at(1, 1) * at(2, 3) - at(1, 0) * at(3, 1) * at(2, 3) - |
284 | at(2, 0) * at(1, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(1, 3); |
285 | cof(0, 3) = -(at(1, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(1, 2) + |
286 | at(3, 0) * at(1, 1) * at(2, 2) - at(1, 0) * at(3, 1) * at(2, 2) - |
287 | at(2, 0) * at(1, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(1, 2)); |
288 | cof(1, 0) = -(at(0, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(0, 3) + |
289 | at(3, 1) * at(0, 2) * at(2, 3) - at(0, 1) * at(3, 2) * at(2, 3) - |
290 | at(2, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(0, 3)); |
291 | cof(1, 1) = at(0, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(0, 3) + |
292 | at(3, 0) * at(0, 2) * at(2, 3) - at(0, 0) * at(3, 2) * at(2, 3) - |
293 | at(2, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(0, 3); |
294 | cof(1, 2) = -(at(0, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(0, 3) + |
295 | at(3, 0) * at(0, 1) * at(2, 3) - at(0, 0) * at(3, 1) * at(2, 3) - |
296 | at(2, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(0, 3)); |
297 | cof(1, 3) = at(0, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(0, 2) + |
298 | at(3, 0) * at(0, 1) * at(2, 2) - at(0, 0) * at(3, 1) * at(2, 2) - |
299 | at(2, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(0, 2); |
300 | cof(2, 0) = at(0, 1) * at(1, 2) * at(3, 3) + at(1, 1) * at(3, 2) * at(0, 3) + |
301 | at(3, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(3, 2) * at(1, 3) - |
302 | at(1, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(1, 2) * at(0, 3); |
303 | cof(2, 1) = -(at(0, 0) * at(1, 2) * at(3, 3) + at(1, 0) * at(3, 2) * at(0, 3) + |
304 | at(3, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(3, 2) * at(1, 3) - |
305 | at(1, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(1, 2) * at(0, 3)); |
306 | cof(2, 2) = at(0, 0) * at(1, 1) * at(3, 3) + at(1, 0) * at(3, 1) * at(0, 3) + |
307 | at(3, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(3, 1) * at(1, 3) - |
308 | at(1, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(1, 1) * at(0, 3); |
309 | cof(2, 3) = -(at(0, 0) * at(1, 1) * at(3, 2) + at(1, 0) * at(3, 1) * at(0, 2) + |
310 | at(3, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(3, 1) * at(1, 2) - |
311 | at(1, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(1, 1) * at(0, 2)); |
312 | cof(3, 0) = -(at(0, 1) * at(1, 2) * at(2, 3) + at(1, 1) * at(2, 2) * at(0, 3) + |
313 | at(2, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(2, 2) * at(1, 3) - |
314 | at(1, 1) * at(0, 2) * at(2, 3) - at(2, 1) * at(1, 2) * at(0, 3)); |
315 | cof(3, 1) = at(0, 0) * at(1, 2) * at(2, 3) + at(1, 0) * at(2, 2) * at(0, 3) + |
316 | at(2, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(2, 2) * at(1, 3) - |
317 | at(1, 0) * at(0, 2) * at(2, 3) - at(2, 0) * at(1, 2) * at(0, 3); |
318 | cof(3, 2) = -(at(0, 0) * at(1, 1) * at(2, 3) + at(1, 0) * at(2, 1) * at(0, 3) + |
319 | at(2, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(2, 1) * at(1, 3) - |
320 | at(1, 0) * at(0, 1) * at(2, 3) - at(2, 0) * at(1, 1) * at(0, 3)); |
321 | cof(3, 3) = at(0, 0) * at(1, 1) * at(2, 2) + at(1, 0) * at(2, 1) * at(0, 2) + |
322 | at(2, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(2, 1) * at(1, 2) - |
323 | at(1, 0) * at(0, 1) * at(2, 2) - at(2, 0) * at(1, 1) * at(0, 2); |
324 | break; |
325 | |
326 | default: |
327 | UNREACHABLE(); |
328 | break; |
329 | } |
330 | |
331 | // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the |
332 | // determinant of A. |
333 | Matrix<T> adjugateMatrix(cof.transpose()); |
334 | T det = determinant(); |
335 | Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns()); |
336 | for (unsigned int i = 0; i < rows(); i++) |
337 | for (unsigned int j = 0; j < columns(); j++) |
338 | result(i, j) = det ? adjugateMatrix(i, j) / det : T(); |
339 | |
340 | return result; |
341 | } |
342 | |
343 | void setToIdentity() |
344 | { |
345 | ASSERT(rows() == columns()); |
346 | |
347 | const auto one = T(1); |
348 | const auto zero = T(0); |
349 | |
350 | for (auto &e : mElements) |
351 | e = zero; |
352 | |
353 | for (unsigned int i = 0; i < rows(); ++i) |
354 | { |
355 | const auto pos = i * columns() + (i % columns()); |
356 | mElements[pos] = one; |
357 | } |
358 | } |
359 | |
360 | template <unsigned int Size> |
361 | static void setToIdentity(T (&matrix)[Size]) |
362 | { |
363 | static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square." ); |
364 | |
365 | const auto cols = gl::iSquareRoot<Size>(); |
366 | const auto one = T(1); |
367 | const auto zero = T(0); |
368 | |
369 | for (auto &e : matrix) |
370 | e = zero; |
371 | |
372 | for (unsigned int i = 0; i < cols; ++i) |
373 | { |
374 | const auto pos = i * cols + (i % cols); |
375 | matrix[pos] = one; |
376 | } |
377 | } |
378 | |
379 | protected: |
380 | std::vector<T> mElements; |
381 | unsigned int mRows; |
382 | unsigned int mCols; |
383 | }; |
384 | |
385 | class Mat4 : public Matrix<float> |
386 | { |
387 | public: |
388 | Mat4(); |
389 | Mat4(const Matrix<float> generalMatrix); |
390 | Mat4(const std::vector<float> &elements); |
391 | Mat4(const float *elements); |
392 | Mat4(float m00, |
393 | float m01, |
394 | float m02, |
395 | float m03, |
396 | float m10, |
397 | float m11, |
398 | float m12, |
399 | float m13, |
400 | float m20, |
401 | float m21, |
402 | float m22, |
403 | float m23, |
404 | float m30, |
405 | float m31, |
406 | float m32, |
407 | float m33); |
408 | |
409 | static Mat4 Rotate(float angle, const Vector3 &axis); |
410 | static Mat4 Translate(const Vector3 &t); |
411 | static Mat4 Scale(const Vector3 &s); |
412 | static Mat4 Frustum(float l, float r, float b, float t, float n, float f); |
413 | static Mat4 Perspective(float fov, float aspectRatio, float n, float f); |
414 | static Mat4 Ortho(float l, float r, float b, float t, float n, float f); |
415 | |
416 | Mat4 product(const Mat4 &m); |
417 | Vector4 product(const Vector4 &b); |
418 | void dump(); |
419 | }; |
420 | |
421 | } // namespace angle |
422 | |
423 | #endif // COMMON_MATRIX_UTILS_H_ |
424 | |