# TSL Compute Shaders
|
|
|
|
Compute shaders run on the GPU and process many data elements in parallel. TSL (Three.js Shading Language) lets you write them in JavaScript.
|
|
|
|
## Basic Setup
|
|
|
|
```javascript
|
|
import * as THREE from 'three/webgpu';
|
|
import { Fn, instancedArray, instanceIndex, vec3 } from 'three/tsl';
|
|
|
|
// Number of elements; the kernel is dispatched once per element.
const count = 100000;

// Storage buffer holding one vec3 per element.
const positions = instancedArray(count, 'vec3');

// Kernel body: add 0.01 to the x component of this invocation's element.
const moveAlongX = Fn(() => {
  positions.element(instanceIndex).x.addAssign(0.01);
});

// Turn the kernel into a compute node covering `count` invocations.
const computeShader = moveAlongX().compute(count);

// Execute
renderer.compute(computeShader);
|
|
```
|
|
|
|
## Storage Buffers
|
|
|
|
### Instanced Arrays
|
|
|
|
```javascript
|
|
import { instancedArray } from 'three/tsl';
|
|
|
|
// Typed storage buffers, each sized for `count` elements.
// The second argument names the TSL type stored per element.
const positions = instancedArray(count, 'vec3');  // three floats per element
const velocities = instancedArray(count, 'vec3'); // three floats per element
const colors = instancedArray(count, 'vec4');     // four floats per element
const indices = instancedArray(count, 'uint');    // one unsigned integer per element
const values = instancedArray(count, 'float');    // one float per element
|
|
```
|
|
|
|
### Accessing Elements
|
|
|
|
```javascript
|
|
const computeShader = Fn(() => {
  // Views onto this invocation's slot in each storage buffer
  const pos = positions.element(instanceIndex);
  const vel = velocities.element(instanceIndex);

  // Reading: take a single component, or call a node method
  const x = pos.x;
  const speed = vel.length();

  // Writing: whole element, a single component, or a compound assignment
  pos.assign(vec3(0, 0, 0));
  pos.x.assign(1.0);
  pos.addAssign(vel);
})().compute(count);
|
|
```
|
|
|
|
### Accessing Other Elements
|
|
|
|
```javascript
|
|
const computeShader = Fn(() => {
  // Index of the following element; mod(count) wraps the last one back to 0
  const nextIndex = instanceIndex.add(1).mod(count);

  const ownPosition = positions.element(instanceIndex);
  const nextPosition = positions.element(nextIndex);

  // Distance between this element and the following one
  const dist = ownPosition.distance(nextPosition);
})().compute(count);
|
|
```
|
|
|
|
## Compute Shader Patterns
|
|
|
|
### Initialize Particles
|
|
|
|
```javascript
|
|
const computeInit = Fn(() => {
  const pos = positions.element(instanceIndex);
  const vel = velocities.element(instanceIndex);

  // Scale and shift a hash of the seed: hash(seed) * 10 - 5
  const scatter = (seed) => hash(seed).mul(10).sub(5);

  // Each axis hashes a different seed derived from the element index
  pos.x.assign(scatter(instanceIndex));
  pos.y.assign(scatter(instanceIndex.add(1)));
  pos.z.assign(scatter(instanceIndex.add(2)));

  // Start at rest
  vel.assign(vec3(0));
})().compute(count);

// Run once at startup
await renderer.computeAsync(computeInit);
|
|
```
|
|
|
|
### Physics Update
|
|
|
|
```javascript
|
|
// Uniform inputs to the kernel
const gravity = uniform(-9.8);
const groundY = uniform(0);
const deltaTimeUniform = uniform(0); // frame delta, set from JavaScript via .value

const computeUpdate = Fn(() => {
  const pos = positions.element(instanceIndex);
  const vel = velocities.element(instanceIndex);

  // Integrate: gravity changes velocity, then velocity changes position
  vel.y.addAssign(gravity.mul(deltaTimeUniform));
  pos.addAssign(vel.mul(deltaTimeUniform));

  // Below the ground plane: snap back onto it, bounce, and damp horizontal motion
  If(pos.y.lessThan(groundY), () => {
    pos.y.assign(groundY);
    vel.y.assign(vel.y.negate().mul(0.8)); // reverse and keep 80% of the vertical velocity
    vel.xz.mulAssign(0.95); // friction on the horizontal components
  });
})().compute(count);
|
|
|
|
// In animation loop
function animate() {
  // Clamp the frame delta so that one very long frame (for example after the
  // page was paused) does not feed a huge time step into the simulation.
  deltaTimeUniform.value = Math.min(clock.getDelta(), 0.1);

  // Advance the simulation on the GPU, then draw the result
  renderer.compute(computeUpdate);
  renderer.render(scene, camera);
}
|
|
```
|
|
|
|
### Attraction to Point
|
|
|
|
```javascript
|
|
const attractorPos = uniform(new THREE.Vector3(0, 0, 0));
const attractorStrength = uniform(1.0);

// Accelerates every particle toward attractorPos.
// Uses the `positions`, `velocities` and `deltaTimeUniform` nodes from the
// earlier examples.
const computeAttract = Fn(() => {
  const position = positions.element(instanceIndex);
  const velocity = velocities.element(instanceIndex);

  // Direction to attractor
  const toAttractor = attractorPos.sub(position);
  const distance = toAttractor.length();

  // Divide by a clamped length instead of calling normalize(): normalizing a
  // zero-length vector yields NaN, which would permanently corrupt the
  // velocity of any particle sitting exactly on the attractor.
  const direction = toAttractor.div(distance.max(1e-6));

  // Apply force (inverse square falloff); the +0.1 keeps the force finite
  // as the distance approaches zero
  const force = direction.mul(attractorStrength).div(distance.mul(distance).add(0.1));
  velocity.addAssign(force.mul(deltaTimeUniform));
})().compute(count);
|
|
```
|
|
|
|
### Neighbor Interaction (Boids-like)
|
|
|
|
```javascript
|
|
// Boids-style flocking step.
// NOTE: every invocation loops over all `count` particles, so the total work
// grows with count * count. Use a small count or a spatial grid for large flocks.
const computeBoids = Fn(() => {
  const myPos = positions.element(instanceIndex);
  const myVel = velocities.element(instanceIndex);

  // Accumulators; toVar() makes them assignable variables inside the shader
  const separation = vec3(0).toVar();
  const alignment = vec3(0).toVar();
  const cohesion = vec3(0).toVar();
  const neighborCount = int(0).toVar();

  // Check nearby particles
  Loop(count, ({ i }) => {
    If(i.notEqual(instanceIndex), () => {
      const otherPos = positions.element(i);
      const otherVel = velocities.element(i);
      const dist = myPos.distance(otherPos);

      If(dist.lessThan(2.0), () => {
        // Separation: push away from the neighbor, stronger when closer.
        // Skipped when the two particles coincide, because normalizing a
        // zero vector and dividing by a zero distance would produce NaN.
        If(dist.greaterThan(1e-4), () => {
          const diff = myPos.sub(otherPos).normalize().div(dist);
          separation.addAssign(diff);
        });

        // Alignment: sum of neighbor velocities (averaged below)
        alignment.addAssign(otherVel);

        // Cohesion: sum of neighbor positions (averaged below)
        cohesion.addAssign(otherPos);

        neighborCount.addAssign(1);
      });
    });
  });

  If(neighborCount.greaterThan(0), () => {
    const n = neighborCount.toFloat();

    // Turn the sums into averages; cohesion becomes a vector from this
    // particle toward the average neighbor position
    alignment.divAssign(n);
    cohesion.divAssign(n);
    cohesion.assign(cohesion.sub(myPos));

    // Blend the three steering terms into the velocity
    myVel.addAssign(separation.mul(0.05));
    myVel.addAssign(alignment.sub(myVel).mul(0.05));
    myVel.addAssign(cohesion.mul(0.05));
  });

  // Limit speed
  const speed = myVel.length();
  If(speed.greaterThan(2.0), () => {
    myVel.assign(myVel.normalize().mul(2.0));
  });

  myPos.addAssign(myVel.mul(deltaTimeUniform));
})().compute(count);
|
|
```
|
|
|
|
## Workgroups and Synchronization
|
|
|
|
### Workgroup Size
|
|
|
|
```javascript
|
|
// The workgroup size is the optional second argument of compute().
// It is an array with one entry per dimension and defaults to [64].
const computeShader = Fn(() => {
  // shader code
})().compute(count, [64]);
|
|
```
|
|
|
|
### Barriers
|
|
|
|
```javascript
|
|
import {
  invocationLocalIndex, workgroupArray,
  workgroupBarrier, storageBarrier, textureBarrier
} from 'three/tsl';

// Invocations per workgroup; also the size of the shared array below.
const WORKGROUP_SIZE = 64;

const computeShader = Fn(() => {
  // Memory shared by all invocations of one workgroup, one float per invocation
  const sharedData = workgroupArray('float', WORKGROUP_SIZE);

  // Index of this invocation within its workgroup
  const localIndex = invocationLocalIndex;

  // Placeholder for whatever each invocation wants to publish
  const value = instanceIndex.toFloat();

  // Write data
  sharedData.element(localIndex).assign(value);

  // Ensure all workgroup threads reach this point
  workgroupBarrier();

  // Now safe to read data written by other threads.
  // Wrap the index so the last invocation does not read past the end of the array.
  const neighborIndex = localIndex.add(1).mod(WORKGROUP_SIZE);
  const neighborValue = sharedData.element(neighborIndex);
})().compute(count, [WORKGROUP_SIZE]);
|
|
```
|
|
|
|
## Atomic Operations
|
|
|
|
For thread-safe read-modify-write operations:
|
|
|
|
```javascript
|
|
import { atomicAdd, atomicSub, atomicMax, atomicMin, atomicAnd, atomicOr, atomicXor } from 'three/tsl';
|
|
|
|
// Atomic functions only operate on storage declared as atomic;
// toAtomic() marks the buffer accordingly.
const counter = instancedArray(1, 'uint').toAtomic();
const maxValue = instancedArray(1, 'uint').toAtomic();

const computeShader = Fn(() => {
  // Placeholder per-invocation value to feed into the running maximum
  const localValue = instanceIndex;

  // Atomically increment counter
  atomicAdd(counter.element(0), 1);

  // Atomic max
  atomicMax(maxValue.element(0), localValue);
})().compute(count);
|
|
```
|
|
|
|
## Using Compute Results in Materials
|
|
|
|
### Instanced Mesh with Computed Positions
|
|
|
|
```javascript
|
|
import { positionLocal } from 'three/tsl';

// Create instanced mesh
const geometry = new THREE.SphereGeometry(0.1, 16, 16);
const material = new THREE.MeshStandardNodeMaterial();

// Use computed positions: offset each vertex by its instance's position.
// positionNode replaces the vertex position, so assigning the storage element
// alone would collapse every vertex of the sphere onto a single point.
material.positionNode = positionLocal.add(positions.element(instanceIndex));

// Optionally use computed colors
material.colorNode = colors.element(instanceIndex);

const mesh = new THREE.InstancedMesh(geometry, material, count);
scene.add(mesh);
|
|
```
|
|
|
|
### Points with Computed Positions
|
|
|
|
```javascript
|
|
// Zero-filled 'position' attribute with one xyz triple per point
const placeholderPositions = new THREE.Float32BufferAttribute(new Float32Array(count * 3), 3);
const geometry = new THREE.BufferGeometry();
geometry.setAttribute('position', placeholderPositions);

// Size is constant; color and position are read from the compute buffers
const material = new THREE.PointsNodeMaterial();
material.sizeNode = float(5.0);
material.colorNode = colors.element(instanceIndex);
material.positionNode = positions.element(instanceIndex);

const points = new THREE.Points(geometry, material);
scene.add(points);
|
|
```
|
|
|
|
## Execution Methods
|
|
|
|
```javascript
|
|
// Submit a compute dispatch. The call returns as soon as the work has been
// handed to the GPU; it does not wait for the GPU to finish.
renderer.compute(computeShader);

// Promise-based variant; awaiting it lets later code run after the call has resolved
await renderer.computeAsync(computeShader);

// Multiple computes, submitted in this order
renderer.compute(computeInit);
renderer.compute(computePhysics);
renderer.compute(computeCollisions);
|
|
```
|
|
|
|
## Reading Back Data (GPU to CPU)
|
|
|
|
```javascript
|
|
// Read a storage buffer back from the GPU.
// `positions.value` is the buffer attribute behind the instancedArray node;
// getArrayBufferAsync resolves with a copy of its GPU contents.
const arrayBuffer = await renderer.getArrayBufferAsync(positions.value);

// View the raw bytes as floats.
// NOTE(review): vec3 elements may be padded to four floats each on the GPU —
// check readBuffer.length against `count` before indexing by element.
const readBuffer = new Float32Array(arrayBuffer);
|
|
```
|
|
|
|
## Complete Example: Particle System
|
|
|
|
```javascript
|
|
import * as THREE from 'three/webgpu';
|
|
import {
|
|
Fn, If, instancedArray, instanceIndex, uniform,
|
|
vec3, float, hash, time
|
|
} from 'three/tsl';
|
|
|
|
// Particle count and per-particle storage buffers
const count = 50000;
const positions = instancedArray(count, 'vec3');  // xyz position
const velocities = instancedArray(count, 'vec3'); // xyz velocity
const lifetimes = instancedArray(count, 'float'); // remaining lifetime

// Uniforms shared by the init and update kernels
const emitterPos = uniform(new THREE.Vector3(0, 0, 0)); // spawn point
const gravity = uniform(-2.0);                          // added to velocity.y, scaled by dt
const dt = uniform(0);                                  // frame delta, set in the animation loop
|
|
|
|
// Initialize: give every particle a spawn position, launch velocity and lifetime
const computeInit = Fn(() => {
  const position = positions.element(instanceIndex);
  const velocity = velocities.element(instanceIndex);
  const lifetime = lifetimes.element(instanceIndex);

  // Every particle starts at the emitter
  position.assign(emitterPos);

  // Hash-derived heading around the Y axis and launch speed (hash * 2 + 1)
  const heading = hash(instanceIndex).mul(Math.PI * 2);
  const launchSpeed = hash(instanceIndex.add(1)).mul(2).add(1);

  // Full speed upward, 30% of it sideways along the heading
  velocity.y.assign(launchSpeed);
  velocity.x.assign(heading.cos().mul(launchSpeed).mul(0.3));
  velocity.z.assign(heading.sin().mul(launchSpeed).mul(0.3));

  // Hash-derived lifetime (hash * 2 + 1)
  lifetime.assign(hash(instanceIndex.add(2)).mul(2).add(1));
})().compute(count);
|
|
|
|
// Update: integrate motion, age the particle, and respawn it once its lifetime runs out
const computeUpdate = Fn(() => {
  const position = positions.element(instanceIndex);
  const velocity = velocities.element(instanceIndex);
  const lifetime = lifetimes.element(instanceIndex);

  // Gravity changes velocity, velocity changes position, time reduces the lifetime
  velocity.y.addAssign(gravity.mul(dt));
  position.addAssign(velocity.mul(dt));
  lifetime.subAssign(dt);

  // Respawn dead particles
  If(lifetime.lessThan(0), () => {
    // Mixing the current time into the seed varies the hash between respawns
    const seed = instanceIndex.add(time.mul(1000));

    position.assign(emitterPos);

    const heading = hash(seed).mul(Math.PI * 2);
    const launchSpeed = hash(seed.add(1)).mul(2).add(1);

    velocity.x.assign(heading.cos().mul(launchSpeed).mul(0.3));
    velocity.y.assign(launchSpeed);
    velocity.z.assign(heading.sin().mul(launchSpeed).mul(0.3));

    lifetime.assign(hash(seed.add(2)).mul(2).add(1));
  });
})().compute(count);
|
|
|
|
// Material: position comes from the compute buffer, size and color are constant
const material = new THREE.PointsNodeMaterial();
material.colorNode = vec3(1, 0.5, 0.2);
material.sizeNode = float(3.0);
material.positionNode = positions.element(instanceIndex);

// Geometry: zero-filled placeholder 'position' attribute, one xyz triple per particle
const placeholderPositions = new THREE.Float32BufferAttribute(new Float32Array(count * 3), 3);
const geometry = new THREE.BufferGeometry();
geometry.setAttribute('position', placeholderPositions);

const points = new THREE.Points(geometry, material);
scene.add(points);
|
|
|
|
// Init: run the initialization kernel once before the first frame
await renderer.computeAsync(computeInit);

// Animation loop.
// `renderer`, `scene`, `camera` and `clock` are assumed to be created by the
// surrounding application code; this example does not declare them.
function animate() {
  // Clamp the frame delta so one long frame cannot produce a huge simulation step
  dt.value = Math.min(clock.getDelta(), 0.1);

  // Advance the particles on the GPU, then draw them
  renderer.compute(computeUpdate);
  renderer.render(scene, camera);
}

// Register the loop with the renderer; without this, animate() is never called.
renderer.setAnimationLoop(animate);
|
|
```
|