qttn.dev

レイトレーシング

タエチャスク ナッチャノン 05241011

WebGPU

基本課題 1 は WebGL を用いてベジェ曲線を描画するものであった。基本課題 2 は Three.js をそのまま使って、逆運動学を実装するものであった。

今回の課題では、WebGPU を使ってレイトレーシングを実装してみる。

WebGPU は、WebGL の後継となる次世代のグラフィックス API であり、より低レベルな制御が可能である。基本的には WebGL と似たような Vertex/Fragment シェーダーの仕組みを持っているが、レイトレーシングはそのような仕組みと違って、各ピクセルごとに光線を追跡していく手法である。

それを WebGPU 上で実装するためには、まずは Compute Shader を使って、Image Buffer を生成し、もう一つのパイプラインに渡して Vertex/Fragment Shader で描画する必要がある。

簡単な描画

まずは Compute Shader はさておき、CPU から Image Buffer を生成して、直接 Vertex/Fragment Shader に渡す方法を試してみる。

class RayTracingRendererHTMLCanvasElement を受け取り、context から format を取得する。ただし、device 関連の初期化からは非同期で行うため、init() メソッドを呼び出す必要がある。

shader は wgsl で書かれ、renderPipelinecomputePipeline の二つのパイプラインを持つ。まずは presentShader を定義する。

// presentShader.wgsl
// struct definitions

@group(0) @binding(0) var<storage, read> imageBuffer: array<vec4<f32>>;
@group(0) @binding(1) var<uniform> camera: Camera;

const XYUV = array<vec4<f32>, 6>(
  vec4(-1.0, 1.0, 0.0, 0.0),   // 0
  vec4(1.0, 1.0, 1.0, 0.0),    // 1
  vec4(-1.0, -1.0, 0.0, 1.0),  // 2
  vec4(1.0, 1.0, 1.0, 0.0),    // 3
    vec4(-1.0, -1.0, 0.0, 1.0), // 4
  vec4(1.0, -1.0, 1.0, 1.0)    // 5
);

@vertex
fn vertexMain(@builtin(vertex_index) vertexIndex: u32) -> VertexOutput {
  var output: VertexOutput;
  output.position = vec4(XYUV[vertexIndex].xy, 0.0, 1.0);
  output.uv = XYUV[vertexIndex].zw;
  return output;
}

@fragment
fn fragmentMain(@location(0) uv: vec2<f32>) -> @location(0) vec4<f32> {
  let x = u32(f32(uv.x) * f32(camera.viewport.x));
  let y = u32(f32(uv.y) * f32(camera.viewport.y));
  let index = clamp(x + y * camera.viewport.x, 0u, camera.viewport.x * camera.viewport.y - 1u);
  return imageBuffer[index];
}

XYUV は、頂点シェーダーで使用される座標を UV 座標に変換するための配列である。可視化すると、以下のような形になる。

0,0
    0_______________1,3
    |               /|
    |             /  |
    |           /    |
    |         /      |
    |       /        |
    |     /          |
    |   /            |
    | /              |
    2,4______________|5
                        1,1

パイプラインは​​ WebGL と似ており、シェーダーをコンパイルして、GPU 上のバッファを作成し、CPU からデータを移動させる。

今回は CameraImageBuffer クラスを定義し、以下の Interface を実装する。

interface StateBuffer<T extends TypedArray> {
  data: T;
  device: GPUDevice;
  buffer: GPUBuffer;
  writeBuffer: () => void;
  createBuffer: () => void;
  destroy: () => void;
}

Camera は現状 viewport の情報のみ保持しており、ImageBuffer は画面のピクセルごとの RGBA 値を保持している。

class Camera implements StateBuffer<Uint32Array> {
  data: Uint32Array = new Uint32Array(2);

  device: GPUDevice;
  buffer!: GPUBuffer;

  private readonly viewportOffset = 0;

  private readonly bufferConfig = () => ({
    label: "Camera Buffer",
    size: this.data.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });

  constructor(device: GPUDevice, viewport: [number, number]) {
    this.device = device;
    this.viewport = viewport;
    this.createBuffer();
  }

  get viewport(): [number, number] {
    return [this.data[0], this.data[1]];
  }

  set viewport(viewport: [number, number]) {
    this.data.set(viewport, this.viewportOffset);
  }

  createBuffer() {
    this.buffer?.destroy();
    this.buffer = this.device.createBuffer(this.bufferConfig());
  }

  writeBuffer() {
    this.device.queue.writeBuffer(this.buffer, 0, this.data);
  }

  destroy() {
    this.buffer?.destroy();
  }
}

(現状は複雑な形になっているが、プログラムを少しずつ抽象化するともっと簡単になる可能性がある。)

createBuffer は元のバッファが存在すれば破棄し、バイト数によってデバイスのバッファを作成する。

writeBuffer は CPU の TypedArray から GPU の buffer にコピーする。

destroy はアプリケーションが unmount するときに呼ば出される。

最後に render は以下のように定義される。

class RayTracingRenderer {
  // ...
  render() {
    if (!this.isInitialized() || !this.renderPipeline) {
      throw Error(
        "Renderer is not initialized. Please check if it's initialized before calling this method.",
      );
    }

    const commandEncoder = this.device.createCommandEncoder({
      label: "renderEncoder",
    });

    const renderPass = commandEncoder.beginRenderPass(
      this.renderPassDescriptor,
    );

    renderPass.setPipeline(this.renderPipeline);
    renderPass.setBindGroup(0, this.renderBindGroup);
    renderPass.draw(6, 1, 0, 0);
    renderPass.end();

    this.device.queue.submit([commandEncoder.finish()]);
  }
  // ...
}

つまり、上に定義した 6 個の頂点を描画し、device のキューにコマンドを送信する。

CPU ボール

ResizeObserver の設定は省略するが、そのコールバック関数は handleResize である。

handleResize はキャンバスから状態を初期化し、bindGroup を作成する。 bindGroup とは何かというと、WebGL の VertexAttribute と似ており、頂点の性質を GPU のどこに割り当てるか指定するバインディンググループである。

class RayTracingRenderer {
  // ...
  handleResize() {
    if (!this.device) {
      // Non-error return (maybe it's just not initialized yet. pls chill)
      return;
    }

    const canvas = this.canvas;
    const dpr = window.devicePixelRatio || 1;
    const width = Math.max(
      1,
      Math.min(
        Math.floor(canvas.clientWidth * dpr),
        this.device.limits.maxTextureDimension2D,
      ),
    );
    const height = Math.max(
      1,
      Math.min(
        Math.floor(canvas.clientHeight * dpr),
        this.device.limits.maxTextureDimension2D,
      ),
    );
    canvas.width = width;
    canvas.height = height;

    this.state.setFromCanvas(canvas);
    this.bindGroup = this.device.createBindGroup({
      label: "bindGroup",
      layout: this.renderPipeline.getBindGroupLayout(0),
      entries: [
        {
          binding: 0,
          resource: {
            buffer: this.state.imageBuffer.buffer,
          },
        },
        {
          binding: 1,
          resource: {
            buffer: this.state.camera.buffer,
          },
        },
      ],
    });

    this.paint();
  }
}

以下は paint 関数である。この節では、CPU から直接 imageBuffer に書き込むため、このような for ループが使われている。

class RayTracingRenderer {
  // ...
  paint() {
    const centerX = this.canvas.width / 2;
    const centerY = this.canvas.height / 2;
    const radius = 100;

    for (let i = 0; i < this.state.imageBuffer.data.length; i += 4) {
      const x = (i / 4) % this.canvas.width;
      const y = Math.floor(i / 4 / this.canvas.width);
      const dx = x - centerX;
      const dy = y - centerY;
      const distance = Math.sqrt(dx * dx + dy * dy);

      if (distance < radius) {
        this.state.imageBuffer.data[i] = 1; // R
        this.state.imageBuffer.data[i + 1] = 0; // G
        this.state.imageBuffer.data[i + 2] = 0; // B
        this.state.imageBuffer.data[i + 3] = 1; // A
      } else {
        this.state.imageBuffer.data[i] = 0; // R
        this.state.imageBuffer.data[i + 1] = 0; // G
        this.state.imageBuffer.data[i + 2] = 0; // B
        this.state.imageBuffer.data[i + 3] = 1; // A
      }
    }
    this.state.imageBuffer.writeBuffer();
    this.state.camera.writeBuffer();
  }

  // ...
}

こうすると、距離が半径より小さいと赤を描画し、さもないと黒を描画するようなプログラムが作れた。

GPU ボール

次に、Compute Shader を使って GPU 上で Image Buffer を生成してみよう。

追加するのは computeShadercomputePipeline である。

@group(0) @binding(0) var<storage, read_write> imageBuffer: array<vec4<f32>>;
@group(0) @binding(1) var<uniform> camera: Camera;

@compute @workgroup_size(${WORKGROUP_SIZE_X}, ${WORKGROUP_SIZE_Y}, 1)
fn computeMain(@builtin(global_invocation_id) gId: vec3<u32>) {
  let viewport = camera.viewport;

  if (gId.x >= viewport.x || gId.y >= viewport.y) {
    return;
  }

  let pixel = gId.x + gId.y * viewport.x;

  let center = vec2<f32>(viewport) / vec2<f32>(2.0);
  let radius = 100.0;
  let uv = vec2<f32>(f32(gId.x), f32(gId.y));

  let d = distance(uv, center);

  var color: vec4<f32> = vec4<f32>(0.0, 0.0, 0.0, 1.0);

  if (d < radius) {
    color = vec4<f32>(0.0, 1.0, 0.0, 1.0); // Green
  }

  imageBuffer[pixel] = color;
}

次に、WORKGROUP_SIZE_XWORKGROUP_SIZE_Y のサイズを持つ、合計で this.canvas.width と this.canvas.height のピクセル数を持つ workgroup をディスパッチする。

const computePass = commandEncoder.beginComputePass(this.computePassDescriptor);

computePass.setPipeline(this.computePipeline);
computePass.setBindGroup(0, this.computeBindGroup);
computePass.dispatchWorkgroups(
  Math.ceil(this.canvas.width / WORKGROUP_SIZE_X),
  Math.ceil(this.canvas.height / WORKGROUP_SIZE_Y),
);
computePass.end();

WebGPU の魅力は Workgroup を使って、GPU 上で並列処理ができることである。@workgroup_size でワークグループのサイズを指定し、@builtin(global_invocation_id) でグローバルな実行 IDを取得し、各ピクセルごとに処理を行うことができる。

今回はもう paint() 関数は不要で、それと似たような処理を Compute Shader で行う。

ただし、確実に変化が見えるように、色を赤から緑に変更している。

カメラ

レイトレーシングの要素の一つとして、カメラの位置、向き、視野角などを設定する必要がある。

Camera クラスは、viewport の情報を保持しているが、以下のようなプロパティを追加する。

class Camera implements StateBuffer<Float32Array> {
  readonly viewport: Vector2;
  readonly fovy: number;
  readonly aspect: number;
  readonly position: Vector3;
  readonly direction: Vector3;
  readonly up: Vector3;

  get properties() {
    return [
      { name: "viewport", size: 2 },
      { name: "fovy", size: 1 },
      { name: "aspect", size: 1 },
      { name: "position", size: 4 },
      { name: "direction", size: 4 },
      { name: "up", size: 4 },
    ];
  }
}

次に、交点テストを行うためのシェーダーを追加する。

@compute @workgroup_size(${WORKGROUP_SIZE_X}, ${WORKGROUP_SIZE_Y}, 1)
fn computeMain(@builtin(global_invocation_id) gId: vec3<u32>) {
  let viewport = vec2<u32>(camera.viewport);

  if (gId.x >= viewport.x || gId.y >= viewport.y) {
    return;
  }

  let pixel = gId.x + gId.y * viewport.x;

  let center = camera.viewport / vec2<f32>(2.0);
  let uv = vec2<f32>(f32(gId.x), f32(gId.y));

  if (renderMode == 0) {
    // 以上の緑ボールの処理
  }

  if (renderMode == 1) {
    let origin = camera.position;
    let direction = normalize(camera.direction);
    let up = normalize(camera.up);
    let right = normalize(cross(direction, up));
    let fovScale = tan(camera.fovy / 2.0);
    let aspect = camera.aspect;

    // tan(fov / 2) unit * aspect
    // _______________________
    // |                     |
    // |  x_______.          |
    // |__|_______0          | tan(fov / 2) unit
    // |          |          |
    // |          |          |
    // |__________|__________|
    //            | 1 unit
    //            |
                                                          // shift the center
             // but the fovScale is in the range of [-1, 1] so * 2
                                      // normalize to 1 unit
                     // center the ray inside the pixel
    let Px = (2.0 * (f32(gId.x) + 0.5) / camera.viewport.x - 1.0);
    let Py = (1.0 - 2.0 * (f32(gId.y) + 0.5) / camera.viewport.y);
    // Py is the same but inverted because uv.y 0 starts from the top left corner

    let x = Px * fovScale * aspect;
    let y = Py * fovScale;

    let rayDirection = normalize(
      direction + x * right + y * up
    );

    //                ____
    //          a ____   \b__----___
    //       ____         /\\        \
    //   <___            (  \\r       )
    // x <-------oc-----(------        )
    //                   (            )
    //                    \\---___---//
    //
    //

    let oc = origin - circleCenter;
    let a = dot(oc, rayDirection);
    let b = dot(oc, oc) - a * a - circleRadius * circleRadius;

    let hit = b < 0.0 && a < 0.0;

    var color: vec4<f32>;
    if (hit) {
      color = circleColor;
    } else {
      color = vec4<f32>(0.0, 0.0, 0.0, 1.0);
    }

    imageBuffer[pixel] = color;
  }
}

ASCII アートが綺麗すぎてそのまま使うようにしている。

複数のオブジェクト

次に、複数のオブジェクトを描画するために、SceneObject クラスを定義する。

class Sphere {
  position: Vector3;
  radius: number;
  color: Vector4;

  data: Float32Array;
  static readonly size = 8;

  constructor(
    position: Vector3 = new Vector3(0, 0, 0),
    radius: number = 20,
    color: Vector4 = new Vector4(0, 0, 1.0, 1.0),
  ) {
    this.position = position;
    this.radius = radius;
    this.color = color;
    this.data = new Float32Array([
      ...this.position.toArray(),
      this.radius,
      ...this.color.toArray(),
    ]);
  }
}

class SceneObjectState implements StateBuffer<Float32Array> {
  data: Float32Array;
  device: GPUDevice;
  buffer: GPUBuffer;
  objects: Sphere[];

  private readonly bufferConfig = () => ({
    label: "Scene Object Buffer",
    size: this.data.byteLength,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
  });

  constructor(device: GPUDevice, objects: Sphere[] = []) {
    this.device = device;
    this.objects = objects;
    this.data = new Float32Array(objects.length * Sphere.size);
    this.buffer = this.createBuffer();
    this.writeBuffer();
  }

  addObject(object: Sphere) {
    this.objects.push(object);
    this.resizeBuffer();
    this.writeBuffer();
  }

  resizeBuffer() {
    const newSize = this.objects.length * Sphere.size;
    if (this.data.length < newSize) {
      const newData = new Float32Array(newSize);
      newData.set(this.data);
      this.data = newData;
    }
  }

  writeBuffer() {
    for (let i = 0; i < this.objects.length; i++) {
      const object = this.objects[i];
      const offset = i * Sphere.size;
      this.data.set(object.data, offset);
    }
    this.device.queue.writeBuffer(this.buffer, 0, this.data);
  }

  createBuffer() {
    this.buffer?.destroy();
    this.buffer = this.device.createBuffer(this.bufferConfig());
    return this.buffer;
  }

  destroy() {
    this.buffer?.destroy();
  }
}

今回は Sphere を中心にしているが、今度の展望として複数のオブジェクト型を持つことができるようにしたいと思う。

WGSL は動的な配列をサポートしていないため、SceneObjectStateFloat32Array を使って、必要に応じてサイズを変更する。次に、RayTracingStateSceneObjectState を追加する。

class RayTracingState {
  imageBuffer: ImageBuffer;
  camera: Camera;
  renderMode: RenderMode;
  objects: SceneObjectState; // NEW

  constructor(canvas: HTMLCanvasElement, device: GPUDevice) {
    this.imageBuffer = new ImageBuffer(device, canvas.width, canvas.height);
    this.camera = new Camera(device, [canvas.width, canvas.height]);
    this.renderMode = new RenderMode(device);

    // NEW
    this.objects = new SceneObjectState(device, [
      new Sphere(new Vector3(0, 0, 0), 10, new Vector4(0.8, 0.8, 0.3, 1.0)),
      new Sphere(new Vector3(15, 0, 5), 15, new Vector4(0.8, 0.3, 0.8, 1.0)),
      new Sphere(new Vector3(-20, -0, 0), 12, new Vector4(0.3, 0.3, 0.8, 1.0)),
    ]);
  }
  // ...
}

バッファの書き込みや破棄、バインドグループの作成は省略する。

  // ...
    var color = vec4<f32>(0.3, 0.6, 0.8, 1.0);
    var t = 1000.0;

    for (var i = 0u; i < arrayLength(&objects); i++) {
      let sphere = objects[i];
      let oc = sphere.center - origin; // it's confusing when it's negative. i had to change
      let a = dot(oc, rayDirection);
      let b = dot(oc, oc) - a * a;

      if (b > sphere.radius * sphere.radius) {
        continue;
      }

      let d = sqrt(sphere.radius * sphere.radius - b);
      var t0 = a - d; // near intersection
      let t1 = a + d; // far intersection

      if (t0 < 0.0 && t1 < 0.0) {
        continue; // we are behind the sphere
      }

      if (t0 < 0.0) {
        t0 = t1; // we are behind the near intersection, take the far one
      }

      if (t0 < t) {
        t = t0; // we found a closer intersection
        color = sphere.color; // use the sphere color
      }
    }

    imageBuffer[pixel] = color;
  //...

前のバージョンは a を負の値にしていたが、時間単位として使いたいので、混乱を避けるために a を正の値にした。また、空の色を vec4<f32>(0.3, 0.6, 0.8, 1.0) として、背景色を設定している。

Diffuse Lighting

次に、拡散反射を実装してみる。

まずは、少しリファクタリングを行う。


fn generateRay(
  camera: Camera,
  uv: vec2<f32>,
) -> Ray {
  let origin = camera.position;
  let direction = normalize(camera.direction);
  let up = normalize(camera.up);
  let right = normalize(cross(direction, up));
  let fovScale = tan(camera.fovy / 2.0);
  let aspect = camera.aspect;

  // tan(fov / 2) unit * aspect
  // _______________________
  // |                     |
  // |  x_______.          |
  // |__|_______0          | tan(fov / 2) unit
  // |          |          |
  // |          |          |
  // |__________|__________|
  //            | 1 unit
  //            |
                                                        // shift the center
            // but the fovScale is in the range of [-1, 1] so * 2
                                    // normalize to 1 unit
                    // center the ray inside the pixel
  let Px = (2.0 * (uv.x + 0.5) / camera.viewport.x - 1.0);
  let Py = (1.0 - 2.0 * (uv.y + 0.5) / camera.viewport.y);
  // Py is the same but inverted because uv.y 0 starts from the top left corner

  let x = Px * fovScale * aspect;
  let y = Py * fovScale;

  let rayDirection = normalize(
    direction + x * right + y * up
  );

  return Ray(origin, rayDirection);
}

fn sphereIntersect(
  ray: Ray,
  sphere: Sphere,
) -> Intersection {
  let oc = sphere.center - ray.origin;
  let a = dot(oc, ray.direction);
  let b = dot(oc, oc) - a * a - sphere.radius * sphere.radius;
  var hit = false;
  var distance = 0.0;
  var position = vec3<f32>(0.0, 0.0, 0.0);
  var normal = vec3<f32>(0.0, 0.0, 0.0);

  if (b < 0.0 && a > 0.0) {
    hit = true;
    let d = sqrt(sphere.radius * sphere.radius - b);
    let t0 = a - d; // near intersection
    let t1 = a + d; // far intersection
    if (t0 < 0.0 && t1 < 0.0) {
      hit = false; // we are behind the sphere
    } else if (t0 < 0.0) {
      distance = t1; // we are behind the near intersection, take the far one
    } else {
      distance = t0; // we found a closer intersection
    }
    position = ray.origin + distance * ray.direction;
    normal = normalize(position - sphere.center);
  }

  return Intersection(normal, distance, sphere.color, position, hit);
}

以上のヘルパー関数により、もっと簡潔にロジックを記述できるようになった。

Render Mode が 3 になった場合のロジックは以下のようになる。

if (renderMode >= 2) {
    let ray = generateRay(camera, uv);
    let origin = ray.origin;
    let rayDirection = ray.direction;

    var color = vec4<f32>(0.3, 0.6, 0.8, 1.0);
    var t = 1000.0;

    for (var i = 0u; i < arrayLength(&objects); i++) {
      let sphere = objects[i];
      let intersection = sphereIntersect(ray, sphere);

      if (intersection.hit && intersection.distance < t) {
        t = intersection.distance;

        if (renderMode == 2) {
          color = intersection.color;
        } else if (renderMode == 3) {
          let ambient = 0.25;
          let lightDirection = normalize(vec3<f32>(1.0, 1.0, 1.0));
          let lightIntensity = max(dot(intersection.normal, lightDirection), 0.0);
          let diffuse = ambient + (1.0 - ambient) * lightIntensity;
          color = vec4<f32>(
            intersection.color.rgb * diffuse,
            intersection.color.a
          );
        }
      }
    }

    imageBuffer[pixel] = color;
  }

拡散反射の計算は、ambient light と directional light によって光の強さを計算する。

次に、床を追加してみる。

床は比較的に簡単で、下に向いていればヒット、上に向いていればミスとすればよい。

if (renderMode >= 4) {
  // Floor intersection
  let t = -origin.y / rayDirection.y;
  if (t > 0.0) {
    let position = origin + t * rayDirection;
    let gridX = floor(position.x / floorGridSize);
    let gridY = floor(position.z / floorGridSize);
    let isEven = (gridX + gridY) % 2 == 0;
    if (isEven) {
      color = floorBaseColor;
    } else {
      color = floorAccentColor;
    }
    ints = Intersection(
      floorNormal,
      t,
      color,
      position,
      true
    );
  }
}

無駄の計算を削減とオブジェクトの抽象化

今までの実装では、Intersection で保管している情報は多すぎる。

  1. position はいつでも同じ式で ray.origin + ray.direction * distance と計算できる。distance はすでにあるため、position は不要である。
  2. normal は position がわかれば、後でオベジェクトから計算できる。
  3. ヒットかどうかは t の値で判断できる
  4. color もオブジェクトごとに決まっている。

すなわち、Intersection は以下のように簡略化できる。

struct IntersectionV2 {
  t: f32,
  e: u32,
}

e はオブジェクトの id である。(これからエンティティと呼ぶ) t はヒットした距離である。

raycast 関数も同様に、球体だけでなく床、トーラスなどのオブジェクトを扱えるように、オブジェクト ID で識別し、計算方法を分岐する。

まずは、SceneObjectSphere をさらに抽象化し、ECS のようなコンポーネントを定義する。

struct Position {
  value: vec3<f32>, // x, y, z,
  _padding: f32, // padding to align to 16 bytes
}

struct Color {
  value: vec4<f32>, // r, g, b, a
}

struct SphereAttribute {
  radius: f32, // radius
}

struct EntityMetadata {
  position: u32,
  color: u32,
  sphere: u32,
}

以上が、今回使われるコンポーネントである。

CPU 側では、コンポーネントのクラスはこのように定義される。

class ColorComponent extends Vector4 {
  static readonly size = 4;
  static readonly name = "Color" as const;
  static readonly dataclass = Float32Array;

  entityRef: Entity | null = null;

  shouldUpdate = true;

  get data() {
    return this.toArray();
  }

  constructor(
    r: number = 1.0,
    g: number = 1.0,
    b: number = 1.0,
    a: number = 1.0,
  ) {
    super(r, g, b, a);
    return new Proxy(this, {
      get: (target, prop) => {
        return (target as any)[prop];
      },
      set: (target, prop, value) => {
        (target as any).shouldUpdate = true;
        (target as any)[prop] = value;
        if ((target as any).entityRef) {
          (target as any).entityRef.shouldUpdate = true;
        }
        return true;
      },
    });
  }
}

すなわち、ベースクラス、エンティティへの参照、更新フラグ、平坦化されたデータを持つ。 さらにクラスの静的なプロパティとして、サイズ、名前、データクラスを定義する。 サイズは 4 の倍数である必要がある。(16 バイトアライメントのため)

もっとも重要な部分は、EntityRegistry のクラスである。

class EntityRegistry {
  device: GPUDevice;
  entityMaxSize = 32;
  entitySize = 0;

  readonly componentSize = Components.length;

  storage: ComponentStorage;

  get totalIndexDataSize() {
    return this.componentSize * this.entityMaxSize;
  }

  entityMetadata = new Uint32Array(
    this.entityMaxSize * this.componentSize,
  ).fill(0);
  entityMetadataBuffer: GPUBuffer;

  entities: Map<number, Entity> = new Map();

  constructor(device: GPUDevice) {
    this.device = device;
    this.storage = Components.reduce((acc, cur) => {
      acc[cur.name] = {
        instances: Array.from({
          length: this.entityMaxSize,
        }) as any,
        data: new cur.dataclass(
          this.entityMaxSize * cur.dataclass.BYTES_PER_ELEMENT,
        ),
        buffer: this.device.createBuffer({
          label: `${cur.name} Buffer`,
          size: cur.dataclass.BYTES_PER_ELEMENT * this.entityMaxSize,
          usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
        }),
      };
      return acc;
    }, {} as ComponentStorage);
    this.entityMetadataBuffer = this.device.createBuffer({
      label: "Entity Metadata Buffer",
      size: this.totalIndexDataSize * Uint32Array.BYTES_PER_ELEMENT,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
    });
  }

  spawn(): Entity;
  spawn(components: Array<ComponentType>): Entity;

  spawn(components?: Array<ComponentType>): Entity {
    const id = this.entitySize++;
    const entity = new Entity(id);
    this.entities.set(id, entity);
    if (components) {
      entity.addComponentBundle(components);
    }
    return entity;
  }

  writeBuffer() {
    for (const entity of this.entities.values()) {
      if (!entity.shouldUpdate) {
        continue;
      }
      const index = entity.id;
      for (let i = 0; i < this.componentSize; i++) {
        this.entityMetadata[index * this.componentSize + i] = 0;
      }
      for (const [componentName, component] of entity.directComponentMap) {
        this.entityMetadata[
          index * this.componentSize + ComponentIds[componentName]
        ] = 1;
        if (component.shouldUpdate) {
          const meta = ComponentMap[componentName];
          const storage = this.storage[componentName];
          const offset = index * meta.size;
          storage.data.set(component.data, offset);
          this.device.queue.writeBuffer(
            storage.buffer,
            offset * meta.dataclass.BYTES_PER_ELEMENT,
            storage.data,
            offset,
            meta.size,
          );
          component.shouldUpdate = false;
        }
      }
      entity.shouldUpdate = false;
    }

    this.device.queue.writeBuffer(
      this.entityMetadataBuffer,
      0,
      this.entityMetadata,
      0,
      this.totalIndexDataSize,
    );
  }

  destroy() {
    for (const storage of Object.values(this.storage)) {
      storage.buffer.destroy();
    }
  }
}

今回は 32個以下のエンティティを想定している。(delete と動的リサイズがあるともっと複雑になる。) EntityRegistry は、エンティティを spawn する際に、ID を割り当て、Metadata のアクティブ状態を更新し、バッファを平坦化して GPU に書き込むライフサイクルの管理を行うクラスである。

エンティティ自体は以下のように定義される。

class Entity {
  id: number;
  directComponentMap: Map<ComponentName, ComponentType> = new Map();
  shouldUpdate = true;

  constructor(id: number) {
    this.id = id;
  }

  addComponent(component: ComponentType) {
    component.shouldUpdate = true;
    component.entityRef = this;
    const componentName = component.constructor.name as ComponentName;

    if (this.directComponentMap.has(componentName)) {
      throw new Error(
        `Component ${componentName} is already added to this entity.`,
      );
    }

    this.directComponentMap.set(componentName, component);
    this.shouldUpdate = true;
    return this;
  }

  addComponentBundle(components: Array<ComponentType>): Entity {
    for (const component of components) {
      this.addComponent(component);
    }
    return this;
  }
}

shouldUpdate のフラグで、エンティティからコンポーネントを更新することができる。

基本的な使い方は

this.entityRegistry = new EntityRegistry(device);
this.entityRegistry.spawn([
  new PositionComponent(0, 10, 0),
  new ColorComponent(0.8, 0.8, 0.3, 1.0),
  new SphereComponent(10),
]);

const entity = this.entityRegistry.spawn([new PositionComponent(15, 15, 5)]);
// 後からコンポーネントを追加することもできる。
entity.addComponentBundle([
  new ColorComponent(0.8, 0.3, 0.8, 1.0),
  new SphereComponent(15),
]);

this.entityRegistry.spawn([
  new PositionComponent(-20, 12, 0),
  new ColorComponent(0.3, 0.3, 0.8, 1.0),
  new SphereComponent(12),
]);

すると、3つのエンティティが生成される。

エンティティもあって、コンポーネントもあって、残りはシステムである。システムはCPU 側であるためもっとも簡単である。

type System = (rd: RayTracingRenderer) => void;

class RayTracingRenderer {
  systems: System[] = [];

  update() {
    for (const system of this.systems) {
      system(this);
    }
  }
}

// react component 内で使う場合
const stateRef = useRef<{
  angle: number;
}>({
  angle: 0,
});

const jumpingUpandDown: System = (rd) => {
  if (!scrollPassed("ecs-component")) {
    return;
  }

  const state = stateRef.current;

  if (state.angle >= 360) {
    state.angle = 0;
  } else {
    state.angle += 0.1;
  }

  const data =
    rd.state.entityRegistry.entities.get(0)?.directComponentMap.Position;
  if (data) {
    data.y = Math.sin(state.angle) * 10 + 10;
  }
};