From e13458d793ce57e3381451db5a2bdcf1f585fed1 Mon Sep 17 00:00:00 2001 From: Amy Retzerau Date: Sat, 27 Dec 2025 23:50:41 +0100 Subject: [PATCH] feat: added matrix to mathlib --- fastmath.hpp | 262 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 159 insertions(+), 103 deletions(-) diff --git a/fastmath.hpp b/fastmath.hpp index 7fe619d..1bec42c 100644 --- a/fastmath.hpp +++ b/fastmath.hpp @@ -113,7 +113,7 @@ template struct vec { for (int i = 0; i < n; i++) { newV.v[i] = v1.v[i] - v2.v[i]; } - return static_cast(newV); + return newV; } friend std::ostream &operator<<(std::ostream &os, const vec &v) { @@ -192,6 +192,14 @@ template struct vec { decimal f = decimal(1.0) / this->len(); return (*this * f); } + + constexpr static Dev zero() { + Dev newV = {}; + for (int i = 0; i < n; i++) { + newV[i] = decimal(0); + } + return newV; + } }; struct vec2 : public vec<2, vec2> { @@ -233,6 +241,8 @@ struct vec3 : public vec<3, vec3> { }; struct vec4 : public vec<4, vec4> { + constexpr vec4() : vec<4, vec4>() {} + vec4(float x, float y, float z, float w) : vec<4, vec4>(decimal(x), decimal(y), decimal(z), decimal(w)) {} @@ -242,6 +252,7 @@ struct vec4 : public vec<4, vec4> { vec4(int32_t x, int32_t y, int32_t z, int32_t w) : vec<4, vec4>(decimal(x), decimal(y), decimal(z), decimal(w)) {} + vec4(vec3 v, decimal w) : vec<4, vec4>(v.x(), v.y(), v.z(), w) {} decimal &x() { return v[0]; } decimal &y() { return v[1]; } @@ -249,107 +260,152 @@ struct vec4 : public vec<4, vec4> { decimal &w() { return v[3]; } }; -// template struct mat { -// -// mat(decimal newM[n * n]) { -// for (int i = 0; i < n * n; i++) { -// m[i] = newM[i]; -// } -// } -// -// mat(std::vector newM) { -// for (int i = 0; i < n * n; i++) { -// m[i] = newM[i]; -// } -// } -// -// mat() : m{} {} -// -// friend Dev operator+(const mat &m1, const mat &m2) { -// Dev newM = {}; -// -// for (int i = 0; i < n * n; i++) { -// newM.v[i] = m1.m[i] + m2.m[i]; -// } -// return newM; -// } -// friend Dev operator+=(const mat &m1, const mat &m2) { -// Dev newM = {}; -// -// for (int i = 0; i < n * n; i++) { -// newM.m[i] = m1.m[i] + m2.m[i]; -// } -// return newM; -// } -// -// friend Dev operator-(const mat &m1, const mat &m2) { -// Dev newM = {}; -// -// for (int i = 0; i < n * n; i++) { -// newM.m[i] = m1.m[i] - m2.m[i]; -// } -// return static_cast(newM); -// } -// -// friend std::ostream &operator<<(std::ostream &os, const mat &m) { -// os << "(" << m.m[0]; -// for (int i = 1; i < n * n; i++) { -// os << ", " << m.m[i]; -// } -// return (os << ")" << std::endl); -// } -// -// friend Dev operator*(const mat &m, const decimal &d) { -// int32_t f = d.i >> HALF_SHIFT; -// -// Dev newM = {}; -// for (int i = 0; i < n * n; i++) { -// newM.m[i] = (m.m[i].i >> HALF_SHIFT) * f; -// } -// return newM; -// } -// -// friend Dev operator*(const decimal &d, const mat &v) { -// return v * d; -// } -// -// Dev operator*(const mat &mat) { -// Dev newM = {}; -// for (int i = 0; i < n; i++) { -// for (int j = 0; j < n; j++) { -// newM.m += mat.v[i * n] * m[i]; -// } -// } -// return res; -// } -// -// friend bool operator==(const mat &v1, const mat &m2) { -// bool res = true; -// for (int i = 0; i < n; i++) { -// res &= v1.v[i] == m2.v[i]; -// } -// return res; -// } -// bool isSmall() { -// for (int i = 0; i < n; i++) { -// if (!v[i].isSmall()) -// return false; -// } -// return true; -// } -// decimal &operator[](const int &i) { return v[i]; } -// -// decimal len_sq() { return *this * *this; } -// -// decimal len() { return this->len_sq().sqrt(); } -// -// Dev normalize() { -// decimal f = decimal(1.0) / this->len(); -// return (*this * f); -// } -// -// protected: -// decimal m[n * n]; -// }; +template struct mat { + decimal m[n * n]; + + static const int size = n; + friend Dev operator+(const mat &m1, const mat &m2) { + Dev newM = {}; + + for (int i = 0; i < n * n; i++) { + newM.v[i] = m1.m[i] + m2.m[i]; + } + return newM; + } + friend Dev operator+=(const mat &m1, const mat &m2) { + Dev newM = {}; + + for (int i = 0; i < n * n; i++) { + newM.m[i] = m1.m[i] + m2.m[i]; + } + return newM; + } + + friend Dev operator-(const mat &m1, const mat &m2) { + Dev newM = {}; + + for (int i = 0; i < n * n; i++) { + newM.m[i] = m1.m[i] - m2.m[i]; + } + return newM; + } + + friend std::ostream &operator<<(std::ostream &os, const mat &m) { + for (int i = 0; i < n; i++) { + os << "|" << m.m[i * n]; + for (int j = 1; j < n; j++) { + os << ", " << m.m[i * n + j]; + } + os << "|" << "\n"; + } + return (os << std::endl); + } + + template + friend Dev1 operator*(const mat &mat, const vec &v) { + Dev1 newV = vec::zero(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + newV[i] += mat.m[i * n + j] * v.v[j]; + } + } + return newV; + } + + decimal &operator[](const int &i) { return m[i]; } + + friend Dev operator*(const mat &m1, const mat &m2) { + Dev newM = mat::zero(); + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + for (int k = 0; k < n; k++) { + newM[i * n + j] += m1.m[i * n + k] * m2.m[k * n + j]; + } + } + } + return newM; + } + + constexpr static Dev identity() { + Dev newM = {}; + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + if (i == j) + newM.m[i * n + i] = decimal(1.0f); + else + newM.m[i * n + j] = decimal(0.0f); + } + } + return newM; + } + constexpr static Dev zero() { + Dev newM = {}; + for (int i = 0; i < n * n; i++) { + newM[i] = decimal(0); + } + return newM; + } + + inline void set(int x, int y, decimal v) { m[y * n + x] = v; } + inline decimal get(int x, int y) { return m[y * n + x]; } + + friend bool operator==(const mat &m1, const mat &m2) { + bool res = true; + for (int i = 0; i < n * n; i++) { + res &= m1.m[i] == m2.m[i]; + } + return res; + } + bool isSmall() { + for (int i = 0; i < n; i++) { + if (!m[i].isSmall()) + return false; + } + return true; + } + + template Dev1 cutTo() const { + static_assert(Dev1::size < n, "Can only convert to smaller matrix"); + Dev1 newM = mat::zero(); + for (int i = 0; i < Dev1::size; i++) { + for (int j = 0; j < Dev1::size; j++) { + newM.m[Dev1::size * i + j] = m[n * i + j]; + } + } + return newM; + } +}; + +template struct matN : public mat> {}; + +struct mat3 : public mat<3, mat3> {}; + +struct mat4 : public mat<4, mat4> { + static mat4 translation(const vec3 &v) { + mat4 newM = mat4::identity(); + for (int i = 0; i < 3; i++) { + newM[4 * i + 3] = v.v[i]; + } + return newM; + } + static mat4 rotateOnX(float a) { + mat4 newM = mat4::identity(); + newM.m[1 * 4 + 1] = cos(a), newM.m[2 * 4 + 2] = cos(a); + newM.m[1 * 4 + 2] = -sin(a), newM.m[2 * 4 + 1] = sin(a); + return newM; + } + static mat4 rotateOnY(float a) { + mat4 newM = mat4::identity(); + newM.m[0 * 4 + 0] = cos(a), newM.m[2 * 4 + 2] = cos(a); + newM.m[0 * 4 + 2] = sin(a), newM.m[2 * 4 + 0] = -sin(a); + return newM; + } + static mat4 rotateOnZ(float a) { + mat4 newM = mat4::identity(); + newM.m[0 * 4 + 0] = cos(a), newM.m[1 * 4 + 1] = cos(a); + newM.m[1 * 4 + 0] = sin(a), newM.m[0 * 4 + 1] = -sin(a); + return newM; + } +}; #endif