add the world
This commit is contained in:
parent
5f70c779d0
commit
587c3cb92c
14
src/hitable.cpp
Normal file
14
src/hitable.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "ray.cpp"
|
||||
|
||||
struct hit_record {
|
||||
float t;
|
||||
vec3 p;
|
||||
vec3 normal;
|
||||
};
|
||||
|
||||
class hitable {
|
||||
public:
|
||||
__device__ virtual bool hit(const ray& r, float t_min, float t_max, hit_record& rec) const = 0;
|
||||
};
|
26
src/hitable_list.cpp
Normal file
26
src/hitable_list.cpp
Normal file
@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "hitable.cpp"
|
||||
|
||||
class hitable_list: public hitable {
|
||||
public:
|
||||
__device__ hitable_list() {}
|
||||
__device__ hitable_list(hitable **l, int n) { list = l; list_size = n; }
|
||||
__device__ virtual bool hit(const ray& r, float t_min, float t_max, hit_record& rec) const;
|
||||
hitable **list;
|
||||
int list_size;
|
||||
};
|
||||
|
||||
__device__ bool hitable_list::hit(const ray& r, float t_min, float t_max, hit_record& rec) const {
|
||||
hit_record temp_rec;
|
||||
bool hit_anything = false;
|
||||
float closest_so_far = t_max;
|
||||
for (int i = 0; i < list_size; i++) {
|
||||
if (list[i]->hit(r, t_min, closest_so_far, temp_rec)) {
|
||||
hit_anything = true;
|
||||
closest_so_far = temp_rec.t;
|
||||
rec = temp_rec;
|
||||
}
|
||||
}
|
||||
return hit_anything;
|
||||
}
|
70
src/main.cu
70
src/main.cu
@ -2,9 +2,12 @@
|
||||
#include <ctime>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <float.h>
|
||||
|
||||
#include "vec3.cpp"
|
||||
#include "ray.cpp"
|
||||
#include "sphere.cpp"
|
||||
#include "hitable_list.cpp"
|
||||
|
||||
#define checkCudaErrors(val) check_cuda((val), #val, __FILE__, __LINE__)
|
||||
|
||||
@ -17,13 +20,31 @@ void check_cuda(cudaError_t result, const char *func, const char *file, int line
|
||||
}
|
||||
}
|
||||
|
||||
__device__ vec3 color(const ray& r) {
|
||||
vec3 unit_direction = unit_vector(r.direction());
|
||||
float t = 0.5f * (unit_direction.y() + 1.0f);
|
||||
return (1.0f - t)*vec3(1.0,1.0,1.0) + t*vec3(1.0, 0.0, 0.0);
|
||||
__device__ bool hit_sphere(const vec3& center, float radius, const ray& r) {
|
||||
vec3 oc = r.origin() - center;
|
||||
float a = dot(r.direction(), r.direction());
|
||||
float b = dot(oc, r.direction());
|
||||
float c = dot(oc, oc) - radius*radius;
|
||||
float discriminant = b*b - a*c;
|
||||
return (discriminant > 0.0);
|
||||
}
|
||||
|
||||
__global__ void render(vec3 *fb, int max_x, int max_y, vec3 lower_left_corner, vec3 horizontal, vec3 vertical, vec3 origin) {
|
||||
__device__ vec3 color(const ray& r, hitable **world) {
|
||||
hit_record rec;
|
||||
if ((*world)->hit(r, 0.0, FLT_MAX, rec)) {
|
||||
return 0.5f*vec3(rec.normal.x()+1.0f, rec.normal.y()+1.0f, rec.normal.z()+1.0f);
|
||||
}
|
||||
|
||||
vec3 unit_direction = unit_vector(r.direction());
|
||||
float t = 0.5f * (unit_direction.y() + 1.0f);
|
||||
return (1.0f - t)*vec3(1.0,1.0,1.0) + t*vec3(0.5, 0.7, 1.0);
|
||||
}
|
||||
|
||||
__global__ void render(vec3 *fb,
|
||||
int max_x, int max_y,
|
||||
vec3 lower_left_corner, vec3 horizontal, vec3 vertical,
|
||||
vec3 origin,
|
||||
hitable **world) {
|
||||
int x = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int y = threadIdx.y + blockIdx.y * blockDim.y;
|
||||
if ((x >= max_x) || (y >= max_y)) return;
|
||||
@ -32,7 +53,22 @@ __global__ void render(vec3 *fb, int max_x, int max_y, vec3 lower_left_corner, v
|
||||
float v = float(y) / max_y;
|
||||
|
||||
ray r(origin, lower_left_corner + u*horizontal + v*vertical);
|
||||
fb[pixel_idx] = color(r);
|
||||
fb[pixel_idx] = color(r, world);
|
||||
}
|
||||
|
||||
__global__ void create_world(hitable **d_list, int d_list_size, hitable **d_world) {
|
||||
if (threadIdx.x == 0 && blockIdx.y == 0) {
|
||||
d_list[0] = new sphere(vec3(0,0, -1), 0.5);
|
||||
d_list[1] = new sphere(vec3(0,-100.5, -1), 100);
|
||||
*d_world = new hitable_list(d_list, d_list_size);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void free_world(hitable **d_list, int d_list_size, hitable **d_world) {
|
||||
for (int i = 0; i < d_list_size; i++) {
|
||||
delete d_list[i];
|
||||
}
|
||||
delete *d_world;
|
||||
}
|
||||
|
||||
int main() {
|
||||
@ -48,9 +84,19 @@ int main() {
|
||||
int num_pixels = nx*ny;
|
||||
size_t fb_size = num_pixels*sizeof(vec3);
|
||||
|
||||
// allocate frame buffer
|
||||
vec3 *fb;
|
||||
checkCudaErrors(cudaMallocManaged(&fb, fb_size));
|
||||
|
||||
// populate world
|
||||
hitable **d_list;
|
||||
int d_list_size = 2;
|
||||
checkCudaErrors(cudaMalloc((void **)&d_list, d_list_size*sizeof(hitable *)));
|
||||
hitable **d_world;
|
||||
checkCudaErrors(cudaMalloc((void **)&d_world, sizeof(hitable *)));
|
||||
create_world<<<1,1>>>(d_list, d_list_size, d_world);
|
||||
|
||||
// Render frame buffer
|
||||
clock_t start = clock();
|
||||
{
|
||||
dim3 blocks(nx/tx+1, ny/ty+1);
|
||||
@ -60,7 +106,8 @@ int main() {
|
||||
vec3(-2.0, -1.0, -1.0),
|
||||
vec3(4.0, 0.0, 0.0),
|
||||
vec3(0.0, 2.0, 0.0),
|
||||
vec3(0.0, 0.0, 0.0));
|
||||
vec3(0.0, 0.0, 0.0),
|
||||
d_world);
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
}
|
||||
@ -68,6 +115,7 @@ int main() {
|
||||
double timer_seconds = ((double)(stop - start)) / CLOCKS_PER_SEC;
|
||||
std::cout << "took " << timer_seconds << " seconds.\n";
|
||||
|
||||
// Saveing frame buffer
|
||||
FILE *f = fopen(image_filename, "w");
|
||||
assert(f);
|
||||
|
||||
@ -91,7 +139,15 @@ int main() {
|
||||
|
||||
fclose(f);
|
||||
|
||||
// Cleanup
|
||||
free_world<<<1,1>>>(d_list, d_list_size, d_world);
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
checkCudaErrors(cudaFree(d_list));
|
||||
checkCudaErrors(cudaFree(d_world));
|
||||
checkCudaErrors(cudaFree(fb));
|
||||
|
||||
cudaDeviceReset();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "vec3.cpp"
|
||||
|
||||
class ray
|
||||
|
38
src/sphere.cpp
Normal file
38
src/sphere.cpp
Normal file
@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#include "hitable.cpp"
|
||||
|
||||
class sphere: public hitable {
|
||||
public:
|
||||
__device__ sphere() {}
|
||||
__device__ sphere(vec3 cen, float r) : center(cen), radius(r) {};
|
||||
__device__ virtual bool hit(const ray& r, float t_min, float t_max, hit_record& rec) const;
|
||||
vec3 center;
|
||||
float radius;
|
||||
};
|
||||
|
||||
__device__ bool sphere::hit(const ray& r, float t_min, float t_max, hit_record& rec) const {
|
||||
vec3 oc = r.origin() - center;
|
||||
float a = dot(r.direction(), r.direction());
|
||||
float b = dot(oc, r.direction());
|
||||
float c = dot(oc, oc) - radius*radius;
|
||||
float discriminant = b*b - a*c;
|
||||
if (discriminant > 0) {
|
||||
float temp = (-b - sqrt(discriminant))/a;
|
||||
if (temp < t_max && temp > t_min) {
|
||||
rec.t = temp;
|
||||
rec.p = r.point_at_parameter(rec.t);
|
||||
rec.normal = (rec.p - center) / radius;
|
||||
return true;
|
||||
}
|
||||
|
||||
temp = (-b + sqrt(discriminant))/a;
|
||||
if (temp < t_max && temp > t_min) {
|
||||
rec.t = temp;
|
||||
rec.p = r.point_at_parameter(rec.t);
|
||||
rec.normal = (rec.p - center) / radius;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
Loading…
Reference in New Issue
Block a user