Files
SoftwareRenderer/fastmath.hpp
2025-12-20 22:03:34 +01:00

375 lines
10 KiB
C++

#ifndef FASTMATH_H
#define FASTMATH_H
#include <iostream>
#include <math.h>
#include <stdint.h>
#include <stdlib.h>
#include <vector>
#include <iostream>
#define SHIFT_AMOUNT 16
#define HALF_SHIFT (SHIFT_AMOUNT / 2)
#define SHIFT_MASK ((1 << SHIFT_AMOUNT) - 1)
#define TO_FLOAT(x) \
(((float)(x >> SHIFT_AMOUNT)) + \
((double)(x & SHIFT_MASK) / (1 << SHIFT_AMOUNT)))
#define TO_INT(x) ((int32_t)(x * (1 << SHIFT_AMOUNT)))
#define MUL_F(a, b) (((a) >> HALF_SHIFT) * ((b) >> HALF_SHIFT))
#define DIV_F(a, b) ((((a) << HALF_SHIFT) / (b)) << HALF_SHIFT)
struct decimal {
int32_t i;
decimal() : i(0) {}
decimal(float i) : i(TO_INT(i)) {}
decimal(double i) : i(TO_INT(i)) {}
decimal(int32_t i) : i(i) {}
friend std::ostream& operator<<(std::ostream& os, const decimal &d) {
return (os << TO_FLOAT(d.i));
}
friend decimal operator+(const decimal &d1, const decimal &d2) {
return {d1.i + d2.i};
}
decimal &operator+=(const decimal &d) { return (*this) = {i + d.i}; }
friend decimal operator-(const decimal &d1, const decimal &d2) {
return {d1.i - d2.i};
}
friend decimal operator-(const decimal &d) { return {-d.i}; }
friend decimal operator*(const decimal &d1, const decimal &d2) {
return {MUL_F(d1.i, d2.i)};
}
decimal &operator*=(const decimal &d) { return (*this) = {MUL_F(i, d.i)}; }
friend decimal operator/(const decimal &d1, const decimal &d2) {
return {DIV_F(d1.i, d2.i)};
}
friend bool operator<(const decimal &d1, const decimal &d2) {
return d1.i < d2.i;
}
friend bool operator>(const decimal &d1, const decimal &d2) {
return d1.i > d2.i;
}
friend bool operator<=(const decimal &d1, const decimal &d2) {
return d1.i <= d2.i;
}
friend bool operator>=(const decimal &d1, const decimal &d2) {
return d1.i >= d2.i;
}
friend bool operator==(const decimal &d1, const decimal &d2) {
return d1.i == d2.i;
}
friend bool operator!=(const decimal &d1, const decimal &d2) {
return d1.i != d2.i;
}
decimal &operator=(decimal const &in) {
if (this != &in) {
std::destroy_at(this);
std::construct_at(this, in);
}
return *this;
}
decimal sqrt() { return {((int32_t)sqrtf(i)) << HALF_SHIFT}; }
float to_float() { return TO_FLOAT(i); }
bool isSmall() { return (abs(i) < (1 << (HALF_SHIFT - 1))); }
};
template <int n, class Dev> struct vec {
vec(decimal newV[n]) {
for (int i = 0; i < n; i++) {
v[i] = newV[i];
}
}
vec(std::vector<decimal> newV) {
for (int i = 0; i < n; i++) {
v[i] = newV[i];
}
}
vec() : v{} {}
friend Dev operator+(const vec<n, Dev> &v1, const vec<n, Dev> &v2) {
Dev newV = {};
for (int i = 0; i < n; i++) {
newV.v[i] = v1.v[i] + v2.v[i];
}
return newV;
}
friend Dev operator+=(const vec<n, Dev> &v1, const vec<n, Dev> &v2) {
Dev newV = {};
for (int i = 0; i < n; i++) {
newV.v[i] = v1.v[i] + v2.v[i];
}
return newV;
}
friend Dev operator-(const vec<n, Dev> &v1, const vec<n, Dev> &v2) {
Dev newV = {};
for (int i = 0; i < n; i++) {
newV.v[i] = v1.v[i] - v2.v[i];
}
return static_cast<Dev>(newV);
}
friend std::ostream &operator<<(std::ostream &os, const vec<n, Dev> &v) {
os << "(" << v.v[0];
for (int i = 1; i < n; i++) {
os << ", " << v.v[i];
}
return (os << ")" << std::endl);
}
Dev operator-() {
Dev newV = {};
for (int i = 0; i < n; i++) {
newV.v[i] = -v[i];
}
return newV;
}
friend Dev operator*(const vec<n, Dev> &v, const decimal &d) {
int32_t f = d.i >> HALF_SHIFT;
Dev newV = {};
for (int i = 0; i < n; i++) {
newV.v[i] = (v.v[i].i >> HALF_SHIFT) * f;
}
return newV;
}
static Dev max(const vec<n, Dev> &v1, const vec<n, Dev> &v2) {
Dev newV = {};
for (int i = 0; i < n; i++) {
newV.v[i] = std::max(v1.v[i], v2.v[i]);
}
return newV;
}
static Dev min(const vec<n, Dev> &v1, const vec<n, Dev> &v2) {
Dev newV = {};
for (int i = 0; i < n; i++) {
newV.v[i] = std::min(v1.v[i], v2.v[i]);
}
return newV;
}
friend Dev operator*(const decimal &d, const vec<n, Dev> &v) {
return v * d;
}
decimal operator*(const vec<n, Dev> &vec) {
decimal res;
for (int i = 0; i < n; i++) {
res += vec.v[i] * v[i];
}
return res;
}
friend bool operator==(const vec<n, Dev> &v1, const vec<n, Dev> &v2) {
bool res = true;
for (int i = 0; i < n; i++) {
res &= v1.v[i] == v2.v[i];
}
return res;
}
bool isSmall() {
for (int i = 0; i < n; i++) {
if (!v[i].isSmall())
return false;
}
return true;
}
decimal &operator[](const int &i) { return v[i]; }
decimal len_sq() { return *this * *this; }
decimal len() { return this->len_sq().sqrt(); }
Dev normalize() {
decimal f = decimal(1.0) / this->len();
return (*this * f);
}
protected:
decimal v[n];
};
struct vec2 : public vec<2, vec2> {
vec2() : vec<2, vec2>() {}
vec2(float x, float y) : vec<2, vec2>({decimal(x), decimal(y)}) {}
vec2(double x, double y) : vec<2, vec2>({decimal(x), decimal(y)}) {}
vec2(int32_t x, int32_t y) : vec<2, vec2>({decimal(x), decimal(y)}) {}
vec2(decimal x, decimal y) : vec<2, vec2>({x, y}) {}
decimal &x() { return v[0]; }
decimal &y() { return v[1]; }
};
struct vec3 : public vec<3, vec3> {
vec3() : vec<3, vec3>() {}
vec3(float x, float y, float z)
: vec<3, vec3>({decimal(x), decimal(y), decimal(z)}) {}
vec3(double x, double y, double z)
: vec<3, vec3>({decimal(x), decimal(y), decimal(z)}) {}
vec3(int32_t x, int32_t y, int32_t z)
: vec<3, vec3>({decimal(x), decimal(y), decimal(z)}) {}
vec3(decimal x, decimal y, decimal z) : vec<3, vec3>({x, y, z}) {}
decimal &x() { return v[0]; }
decimal &y() { return v[1]; }
decimal &z() { return v[2]; }
vec3 cross(vec3 &v) {
return vec3((y() * v.z()) - (z() * v.y()),
(z() * v.x()) - (x() * v.z()),
(x() * v.y()) - (y() * v.x()));
}
};
struct vec4 : public vec<4, vec4> {
vec4() : vec<4, vec4>() {}
vec4(float x, float y, float z, float w)
: vec<4, vec4>({decimal(x), decimal(y), decimal(z), decimal(w)}) {}
vec4(double x, double y, double z, double w)
: vec<4, vec4>({decimal(x), decimal(y), decimal(z), decimal(w)}) {}
vec4(int32_t x, int32_t y, int32_t z, int32_t w)
: vec<4, vec4>({decimal(x), decimal(y), decimal(z), decimal(w)}) {}
vec4(decimal x, decimal y, decimal z) : vec<4, vec4>({x, y, z}) {}
decimal &x() { return v[0]; }
decimal &y() { return v[1]; }
decimal &z() { return v[2]; }
decimal &w() { return v[3]; }
};
// template <int n, class Dev> struct mat {
//
// mat(decimal newM[n * n]) {
// for (int i = 0; i < n * n; i++) {
// m[i] = newM[i];
// }
// }
//
// mat(std::vector<decimal> newM) {
// for (int i = 0; i < n * n; i++) {
// m[i] = newM[i];
// }
// }
//
// mat() : m{} {}
//
// friend Dev operator+(const mat<n, Dev> &m1, const mat<n, Dev> &m2) {
// Dev newM = {};
//
// for (int i = 0; i < n * n; i++) {
// newM.v[i] = m1.m[i] + m2.m[i];
// }
// return newM;
// }
// friend Dev operator+=(const mat<n, Dev> &m1, const mat<n, Dev> &m2) {
// Dev newM = {};
//
// for (int i = 0; i < n * n; i++) {
// newM.m[i] = m1.m[i] + m2.m[i];
// }
// return newM;
// }
//
// friend Dev operator-(const mat<n, Dev> &m1, const mat<n, Dev> &m2) {
// Dev newM = {};
//
// for (int i = 0; i < n * n; i++) {
// newM.m[i] = m1.m[i] - m2.m[i];
// }
// return static_cast<Dev>(newM);
// }
//
// friend std::ostream &operator<<(std::ostream &os, const mat<n, Dev> &m) {
// os << "(" << m.m[0];
// for (int i = 1; i < n * n; i++) {
// os << ", " << m.m[i];
// }
// return (os << ")" << std::endl);
// }
//
// friend Dev operator*(const mat<n, Dev> &m, const decimal &d) {
// int32_t f = d.i >> HALF_SHIFT;
//
// Dev newM = {};
// for (int i = 0; i < n * n; i++) {
// newM.m[i] = (m.m[i].i >> HALF_SHIFT) * f;
// }
// return newM;
// }
//
// friend Dev operator*(const decimal &d, const mat<n, Dev> &v) {
// return v * d;
// }
//
// Dev operator*(const mat<n, Dev> &mat) {
// Dev newM = {};
// for (int i = 0; i < n; i++) {
// for (int j = 0; j < n; j++) {
// newM.m += mat.v[i * n] * m[i];
// }
// }
// return res;
// }
//
// friend bool operator==(const mat<n, Dev> &v1, const mat<n, Dev> &m2) {
// bool res = true;
// for (int i = 0; i < n; i++) {
// res &= v1.v[i] == m2.v[i];
// }
// return res;
// }
// bool isSmall() {
// for (int i = 0; i < n; i++) {
// if (!v[i].isSmall())
// return false;
// }
// return true;
// }
// decimal &operator[](const int &i) { return v[i]; }
//
// decimal len_sq() { return *this * *this; }
//
// decimal len() { return this->len_sq().sqrt(); }
//
// Dev normalize() {
// decimal f = decimal(1.0) / this->len();
// return (*this * f);
// }
//
// protected:
// decimal m[n * n];
// };
#endif