1//
2// Copyright 2015 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// Matrix:
7// Utility class implementing various matrix operations.
// Supports matrices with 1 to 4 rows/columns; determinant() and inverse() require a square matrix of size 2 to 4.
9//
10// TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this
11// implementation.
12// TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util.
13
14#ifndef COMMON_MATRIX_UTILS_H_
15#define COMMON_MATRIX_UTILS_H_
16
17#include <vector>
18
19#include "common/debug.h"
20#include "common/mathutil.h"
21#include "common/vector_utils.h"
22
23namespace angle
24{
25
26template <typename T>
27class Matrix
28{
29 public:
30 Matrix(const std::vector<T> &elements, const unsigned int numRows, const unsigned int numCols)
31 : mElements(elements), mRows(numRows), mCols(numCols)
32 {
33 ASSERT(rows() >= 1 && rows() <= 4);
34 ASSERT(columns() >= 1 && columns() <= 4);
35 }
36
37 Matrix(const std::vector<T> &elements, const unsigned int size)
38 : mElements(elements), mRows(size), mCols(size)
39 {
40 ASSERT(rows() >= 1 && rows() <= 4);
41 ASSERT(columns() >= 1 && columns() <= 4);
42 }
43
44 Matrix(const T *elements, const unsigned int size) : mRows(size), mCols(size)
45 {
46 ASSERT(rows() >= 1 && rows() <= 4);
47 ASSERT(columns() >= 1 && columns() <= 4);
48 for (size_t i = 0; i < size * size; i++)
49 mElements.push_back(elements[i]);
50 }
51
52 const T &operator()(const unsigned int rowIndex, const unsigned int columnIndex) const
53 {
54 ASSERT(rowIndex < mRows);
55 ASSERT(columnIndex < mCols);
56 return mElements[rowIndex * columns() + columnIndex];
57 }
58
59 T &operator()(const unsigned int rowIndex, const unsigned int columnIndex)
60 {
61 ASSERT(rowIndex < mRows);
62 ASSERT(columnIndex < mCols);
63 return mElements[rowIndex * columns() + columnIndex];
64 }
65
66 const T &at(const unsigned int rowIndex, const unsigned int columnIndex) const
67 {
68 ASSERT(rowIndex < mRows);
69 ASSERT(columnIndex < mCols);
70 return operator()(rowIndex, columnIndex);
71 }
72
73 Matrix<T> operator*(const Matrix<T> &m)
74 {
75 ASSERT(columns() == m.rows());
76
77 unsigned int resultRows = rows();
78 unsigned int resultCols = m.columns();
79 Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols);
80 for (unsigned int i = 0; i < resultRows; i++)
81 {
82 for (unsigned int j = 0; j < resultCols; j++)
83 {
84 T tmp = 0.0f;
85 for (unsigned int k = 0; k < columns(); k++)
86 tmp += at(i, k) * m(k, j);
87 result(i, j) = tmp;
88 }
89 }
90
91 return result;
92 }
93
94 void operator*=(const Matrix<T> &m)
95 {
96 ASSERT(columns() == m.rows());
97 Matrix<T> res = (*this) * m;
98 size_t numElts = res.elements().size();
99 mElements.resize(numElts);
100 memcpy(mElements.data(), res.data(), numElts * sizeof(float));
101 }
102
103 bool operator==(const Matrix<T> &m) const
104 {
105 ASSERT(columns() == m.columns());
106 ASSERT(rows() == m.rows());
107 return mElements == m.elements();
108 }
109
110 bool operator!=(const Matrix<T> &m) const { return !(mElements == m.elements()); }
111
112 bool nearlyEqual(T epsilon, const Matrix<T> &m) const
113 {
114 ASSERT(columns() == m.columns());
115 ASSERT(rows() == m.rows());
116 const auto &otherElts = m.elements();
117 for (size_t i = 0; i < otherElts.size(); i++)
118 {
119 if ((mElements[i] - otherElts[i] > epsilon) && (otherElts[i] - mElements[i] > epsilon))
120 return false;
121 }
122 return true;
123 }
124
125 unsigned int size() const
126 {
127 ASSERT(rows() == columns());
128 return rows();
129 }
130
131 unsigned int rows() const { return mRows; }
132
133 unsigned int columns() const { return mCols; }
134
135 std::vector<T> elements() const { return mElements; }
136 T *data() { return mElements.data(); }
137
138 Matrix<T> compMult(const Matrix<T> &mat1) const
139 {
140 Matrix result(std::vector<T>(mElements.size()), rows(), columns());
141 for (unsigned int i = 0; i < rows(); i++)
142 {
143 for (unsigned int j = 0; j < columns(); j++)
144 {
145 T lhs = at(i, j);
146 T rhs = mat1(i, j);
147 result(i, j) = rhs * lhs;
148 }
149 }
150
151 return result;
152 }
153
154 Matrix<T> outerProduct(const Matrix<T> &mat1) const
155 {
156 unsigned int cols = mat1.columns();
157 Matrix result(std::vector<T>(rows() * cols), rows(), cols);
158 for (unsigned int i = 0; i < rows(); i++)
159 for (unsigned int j = 0; j < cols; j++)
160 result(i, j) = at(i, 0) * mat1(0, j);
161
162 return result;
163 }
164
165 Matrix<T> transpose() const
166 {
167 Matrix result(std::vector<T>(mElements.size()), columns(), rows());
168 for (unsigned int i = 0; i < columns(); i++)
169 for (unsigned int j = 0; j < rows(); j++)
170 result(i, j) = at(j, i);
171
172 return result;
173 }
174
175 T determinant() const
176 {
177 ASSERT(rows() == columns());
178
179 switch (size())
180 {
181 case 2:
182 return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0);
183
184 case 3:
185 return at(0, 0) * at(1, 1) * at(2, 2) + at(0, 1) * at(1, 2) * at(2, 0) +
186 at(0, 2) * at(1, 0) * at(2, 1) - at(0, 2) * at(1, 1) * at(2, 0) -
187 at(0, 1) * at(1, 0) * at(2, 2) - at(0, 0) * at(1, 2) * at(2, 1);
188
189 case 4:
190 {
191 const float minorMatrices[4][3 * 3] = {{
192 at(1, 1),
193 at(2, 1),
194 at(3, 1),
195 at(1, 2),
196 at(2, 2),
197 at(3, 2),
198 at(1, 3),
199 at(2, 3),
200 at(3, 3),
201 },
202 {
203 at(1, 0),
204 at(2, 0),
205 at(3, 0),
206 at(1, 2),
207 at(2, 2),
208 at(3, 2),
209 at(1, 3),
210 at(2, 3),
211 at(3, 3),
212 },
213 {
214 at(1, 0),
215 at(2, 0),
216 at(3, 0),
217 at(1, 1),
218 at(2, 1),
219 at(3, 1),
220 at(1, 3),
221 at(2, 3),
222 at(3, 3),
223 },
224 {
225 at(1, 0),
226 at(2, 0),
227 at(3, 0),
228 at(1, 1),
229 at(2, 1),
230 at(3, 1),
231 at(1, 2),
232 at(2, 2),
233 at(3, 2),
234 }};
235 return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() -
236 at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() +
237 at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() -
238 at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant();
239 }
240
241 default:
242 UNREACHABLE();
243 break;
244 }
245
246 return T();
247 }
248
249 Matrix<T> inverse() const
250 {
251 ASSERT(rows() == columns());
252
253 Matrix<T> cof(std::vector<T>(mElements.size()), rows(), columns());
254 switch (size())
255 {
256 case 2:
257 cof(0, 0) = at(1, 1);
258 cof(0, 1) = -at(1, 0);
259 cof(1, 0) = -at(0, 1);
260 cof(1, 1) = at(0, 0);
261 break;
262
263 case 3:
264 cof(0, 0) = at(1, 1) * at(2, 2) - at(2, 1) * at(1, 2);
265 cof(0, 1) = -(at(1, 0) * at(2, 2) - at(2, 0) * at(1, 2));
266 cof(0, 2) = at(1, 0) * at(2, 1) - at(2, 0) * at(1, 1);
267 cof(1, 0) = -(at(0, 1) * at(2, 2) - at(2, 1) * at(0, 2));
268 cof(1, 1) = at(0, 0) * at(2, 2) - at(2, 0) * at(0, 2);
269 cof(1, 2) = -(at(0, 0) * at(2, 1) - at(2, 0) * at(0, 1));
270 cof(2, 0) = at(0, 1) * at(1, 2) - at(1, 1) * at(0, 2);
271 cof(2, 1) = -(at(0, 0) * at(1, 2) - at(1, 0) * at(0, 2));
272 cof(2, 2) = at(0, 0) * at(1, 1) - at(1, 0) * at(0, 1);
273 break;
274
275 case 4:
276 cof(0, 0) = at(1, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(1, 3) +
277 at(3, 1) * at(1, 2) * at(2, 3) - at(1, 1) * at(3, 2) * at(2, 3) -
278 at(2, 1) * at(1, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(1, 3);
279 cof(0, 1) = -(at(1, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(1, 3) +
280 at(3, 0) * at(1, 2) * at(2, 3) - at(1, 0) * at(3, 2) * at(2, 3) -
281 at(2, 0) * at(1, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(1, 3));
282 cof(0, 2) = at(1, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(1, 3) +
283 at(3, 0) * at(1, 1) * at(2, 3) - at(1, 0) * at(3, 1) * at(2, 3) -
284 at(2, 0) * at(1, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(1, 3);
285 cof(0, 3) = -(at(1, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(1, 2) +
286 at(3, 0) * at(1, 1) * at(2, 2) - at(1, 0) * at(3, 1) * at(2, 2) -
287 at(2, 0) * at(1, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(1, 2));
288 cof(1, 0) = -(at(0, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(0, 3) +
289 at(3, 1) * at(0, 2) * at(2, 3) - at(0, 1) * at(3, 2) * at(2, 3) -
290 at(2, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(0, 3));
291 cof(1, 1) = at(0, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(0, 3) +
292 at(3, 0) * at(0, 2) * at(2, 3) - at(0, 0) * at(3, 2) * at(2, 3) -
293 at(2, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(0, 3);
294 cof(1, 2) = -(at(0, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(0, 3) +
295 at(3, 0) * at(0, 1) * at(2, 3) - at(0, 0) * at(3, 1) * at(2, 3) -
296 at(2, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(0, 3));
297 cof(1, 3) = at(0, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(0, 2) +
298 at(3, 0) * at(0, 1) * at(2, 2) - at(0, 0) * at(3, 1) * at(2, 2) -
299 at(2, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(0, 2);
300 cof(2, 0) = at(0, 1) * at(1, 2) * at(3, 3) + at(1, 1) * at(3, 2) * at(0, 3) +
301 at(3, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(3, 2) * at(1, 3) -
302 at(1, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(1, 2) * at(0, 3);
303 cof(2, 1) = -(at(0, 0) * at(1, 2) * at(3, 3) + at(1, 0) * at(3, 2) * at(0, 3) +
304 at(3, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(3, 2) * at(1, 3) -
305 at(1, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(1, 2) * at(0, 3));
306 cof(2, 2) = at(0, 0) * at(1, 1) * at(3, 3) + at(1, 0) * at(3, 1) * at(0, 3) +
307 at(3, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(3, 1) * at(1, 3) -
308 at(1, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(1, 1) * at(0, 3);
309 cof(2, 3) = -(at(0, 0) * at(1, 1) * at(3, 2) + at(1, 0) * at(3, 1) * at(0, 2) +
310 at(3, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(3, 1) * at(1, 2) -
311 at(1, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(1, 1) * at(0, 2));
312 cof(3, 0) = -(at(0, 1) * at(1, 2) * at(2, 3) + at(1, 1) * at(2, 2) * at(0, 3) +
313 at(2, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(2, 2) * at(1, 3) -
314 at(1, 1) * at(0, 2) * at(2, 3) - at(2, 1) * at(1, 2) * at(0, 3));
315 cof(3, 1) = at(0, 0) * at(1, 2) * at(2, 3) + at(1, 0) * at(2, 2) * at(0, 3) +
316 at(2, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(2, 2) * at(1, 3) -
317 at(1, 0) * at(0, 2) * at(2, 3) - at(2, 0) * at(1, 2) * at(0, 3);
318 cof(3, 2) = -(at(0, 0) * at(1, 1) * at(2, 3) + at(1, 0) * at(2, 1) * at(0, 3) +
319 at(2, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(2, 1) * at(1, 3) -
320 at(1, 0) * at(0, 1) * at(2, 3) - at(2, 0) * at(1, 1) * at(0, 3));
321 cof(3, 3) = at(0, 0) * at(1, 1) * at(2, 2) + at(1, 0) * at(2, 1) * at(0, 2) +
322 at(2, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(2, 1) * at(1, 2) -
323 at(1, 0) * at(0, 1) * at(2, 2) - at(2, 0) * at(1, 1) * at(0, 2);
324 break;
325
326 default:
327 UNREACHABLE();
328 break;
329 }
330
331 // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the
332 // determinant of A.
333 Matrix<T> adjugateMatrix(cof.transpose());
334 T det = determinant();
335 Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns());
336 for (unsigned int i = 0; i < rows(); i++)
337 for (unsigned int j = 0; j < columns(); j++)
338 result(i, j) = det ? adjugateMatrix(i, j) / det : T();
339
340 return result;
341 }
342
343 void setToIdentity()
344 {
345 ASSERT(rows() == columns());
346
347 const auto one = T(1);
348 const auto zero = T(0);
349
350 for (auto &e : mElements)
351 e = zero;
352
353 for (unsigned int i = 0; i < rows(); ++i)
354 {
355 const auto pos = i * columns() + (i % columns());
356 mElements[pos] = one;
357 }
358 }
359
360 template <unsigned int Size>
361 static void setToIdentity(T (&matrix)[Size])
362 {
363 static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square.");
364
365 const auto cols = gl::iSquareRoot<Size>();
366 const auto one = T(1);
367 const auto zero = T(0);
368
369 for (auto &e : matrix)
370 e = zero;
371
372 for (unsigned int i = 0; i < cols; ++i)
373 {
374 const auto pos = i * cols + (i % cols);
375 matrix[pos] = one;
376 }
377 }
378
379 protected:
380 std::vector<T> mElements;
381 unsigned int mRows;
382 unsigned int mCols;
383};
384
385class Mat4 : public Matrix<float>
386{
387 public:
388 Mat4();
389 Mat4(const Matrix<float> generalMatrix);
390 Mat4(const std::vector<float> &elements);
391 Mat4(const float *elements);
392 Mat4(float m00,
393 float m01,
394 float m02,
395 float m03,
396 float m10,
397 float m11,
398 float m12,
399 float m13,
400 float m20,
401 float m21,
402 float m22,
403 float m23,
404 float m30,
405 float m31,
406 float m32,
407 float m33);
408
409 static Mat4 Rotate(float angle, const Vector3 &axis);
410 static Mat4 Translate(const Vector3 &t);
411 static Mat4 Scale(const Vector3 &s);
412 static Mat4 Frustum(float l, float r, float b, float t, float n, float f);
413 static Mat4 Perspective(float fov, float aspectRatio, float n, float f);
414 static Mat4 Ortho(float l, float r, float b, float t, float n, float f);
415
416 Mat4 product(const Mat4 &m);
417 Vector4 product(const Vector4 &b);
418 void dump();
419};
420
421} // namespace angle
422
423#endif // COMMON_MATRIX_UTILS_H_
424