Skip to content

Make snippets more robust, correct line numbers #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions blog/2024-11-21-optimizing-matrix-mul/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ import { WebGpuKernel } from './snippets/naive.tsx';
With Rust GPU, we specify the inputs as arguments to the kernel and configure them with
[procedural macros](https://doc.Rust-lang.org/reference/procedural-macros.html):

import { RustNaiveInputs } from './snippets/naive.tsx';
import { RustNaiveKernel } from './snippets/naive.tsx';

<RustNaiveInputs/>
<RustNaiveKernel/>

This code looks like normal Rust code but _runs entirely on the GPU._

Expand Down Expand Up @@ -301,6 +301,13 @@ improvement over the last kernel.
To stay true to the spirit of Zach's original blog post, we'll wrap things up here and
leave the "fancier" experiments for another time.

### A note on performance

I didn't include performance numbers as I have a different machine than Zach. The
complete runnable code can be [found on
GitHub](https://github.com/Rust-GPU/rust-gpu.github.io/tree/main/blog/2024-11-21-optimizing-matrix-mul/code)
and you can run the benchmarks yourself with `cargo bench`.

## Reflections on porting to Rust GPU

Porting to Rust GPU went quickly, as the kernels Zach used were fairly simple. Most of
Expand Down
9 changes: 5 additions & 4 deletions blog/2024-11-21-optimizing-matrix-mul/snippets/naive.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
</CodeBlock>
);

export const RustNaiveInputs: React.FC = () => (
export const RustNaiveKernel: React.FC = () => (
<Snippet
language="rust"
className="text-xs"
metastring="1-5,7"
showLineNumbers
title="Naive kernel with Rust GPU"
>
Expand All @@ -59,6 +58,7 @@ export const RustNaiveWorkgroupCount: React.FC = () => (
language="rust"
className="text-xs"
lines="26-34"
hash="8abb43d"
title="Calculating on the CPU how many workgroup dispatches are needed"
>
{RustWorkgroupCount}
Expand All @@ -69,7 +69,8 @@ export const RustNaiveDispatch: React.FC = () => (
<Snippet
language="rust"
className="text-xs"
lines="145,147"
lines="152,154"
hash="cbb5295"
strip_leading_spaces
title="Using wgpu on the CPU to dispatch workgroups to the GPU"
>
Expand All @@ -78,7 +79,7 @@ export const RustNaiveDispatch: React.FC = () => (
);

export const RustNaiveWorkgroup: React.FC = () => (
<Snippet language="rust" className="text-xs" lines="7">
<Snippet language="rust" className="text-xs" lines="7" hash="7762339">
{RustKernelSource}
</Snippet>
);
19 changes: 13 additions & 6 deletions blog/2024-11-21-optimizing-matrix-mul/snippets/party.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import RustWgpuBackend from "!!raw-loader!../code/crates/cpu/matmul/src/backends
import RustCpuBackendSource from "!!raw-loader!../code/crates/cpu/matmul/src/backends/cpu.rs";

export const RustPartySettings: React.FC = () => (
<Snippet language="rust" className="text-xs" lines="3,9,11">
<Snippet language="rust" className="text-xs" lines="3,9,11" hash="47bb656">
{RustKernelSource}
</Snippet>
);
Expand All @@ -19,21 +19,28 @@ export const RustIsomorphic: React.FC = () => (
);

export const RustIsomorphicGlam: React.FC = () => (
<Snippet language="rust" lines="15-19" className="text-xs">
<Snippet language="rust" lines="15-19" hash="a3dbf2f" className="text-xs">
{RustIsomorphicSource}
</Snippet>
);

export const RustIsomorphicDeps: React.FC = () => (
<Snippet language="rust" lines="9-20" className="text-xs" title="Cargo.toml">
<Snippet
language="rust"
lines="9-20"
hash="72c14d7"
className="text-xs"
title="Cargo.toml"
>
{RustIsomorphicCargoToml}
</Snippet>
);

export const RustWgpuDimensions: React.FC = () => (
<Snippet
language="rust"
lines="98-111"
lines="108-118"
hash="cbb5295"
className="text-xs"
title="Creating the Dimensions struct on the CPU and writing it to the GPU"
>
Expand All @@ -42,13 +49,13 @@ export const RustWgpuDimensions: React.FC = () => (
);

export const RustCpuBackendHarness: React.FC = () => (
<Snippet language="rust" className="text-xs" lines="30-72">
<Snippet language="rust" className="text-xs" lines="30-79" hash="7ad7cab">
{RustCpuBackendSource}
</Snippet>
);

export const RustCpuBackendTest: React.FC = () => (
<Snippet language="rust" className="text-xs" lines="155-172">
<Snippet language="rust" className="text-xs" lines="174-194" hash="7ad7cab">
{RustCpuBackendSource}
</Snippet>
);
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import RustKernelSource from "!!raw-loader!../code/crates/gpu/workgroup_256/src/
import VariantsSource from "!!raw-loader!../code/crates/cpu/matmul/src/variants.rs";

export const RustWorkgroup256Workgroup: React.FC = () => (
<Snippet language="rust" className="text-xs" lines="7">
<Snippet language="rust" className="text-xs" lines="7" hash="56b3ae8">
{RustKernelSource}
</Snippet>
);
Expand All @@ -13,7 +13,8 @@ export const RustWorkgroup256WorkgroupCount: React.FC = () => (
<Snippet
language="rust"
className="text-xs"
lines="51-64"
lines="51-65"
hash="8abb43d"
title="Calculating how many workgroup dispatches are needed on the CPU"
>
{VariantsSource}
Expand Down
32 changes: 1 addition & 31 deletions blog/2024-11-21-optimizing-matrix-mul/snippets/workgroup_2d.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,44 +15,14 @@ export const RustWorkgroup2d: React.FC = () => (
</Snippet>
);

/*
export const RustWorkgroup2d: React.FC = () => (
<Snippet
language="rust"
className="text-xs"
lines="7-8,15-16"
title="2D workgroup kernel with Rust GPU"
>
{RustKernelSource}
</Snippet>
);
*/

export const RustWorkgroup2dWorkgroup: React.FC = () => (
<Snippet language="rust" className="text-xs" lines="7">
{RustKernelSource}
</Snippet>
);

export const RustWorkgroup2dWorkgroupCount: React.FC = () => (
<Snippet
language="rust"
className="text-xs"
lines="82-94"
hash="8abb43d"
title="Calculating how many workgroup dispatches are needed on the CPU"
>
{VariantsSource}
</Snippet>
);

export const RustWorkgroup2dWgpuDispatch: React.FC = () => (
<Snippet
language="rust"
className="text-xs"
lines="144,145,147"
strip_leading_spaces
title="Using wgpu on the CPU to dispatch to the GPU"
>
{WgpuBackendSource}
</Snippet>
);
75 changes: 66 additions & 9 deletions src/components/Snippet/index.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import React from "react";
import React, { useEffect, useState } from "react";
import CodeBlock from "@theme/CodeBlock";

interface SnippetProps extends React.ComponentProps<typeof CodeBlock> {
Expand All @@ -8,27 +8,73 @@ interface SnippetProps extends React.ComponentProps<typeof CodeBlock> {
lines?: string;
omitted_placeholder?: string;
strip_leading_spaces?: boolean;
/**
* Optional short hash of the content (first N characters of SHA-256),
* required only when `lines` is specified.
*/
hash?: string;
}

/**
* A component for rendering a snippet of code, optionally filtering lines,
* showing ellipses for omissions, and stripping all leading spaces.
* showing ellipses for omissions, stripping leading spaces, and validating hash.
*/
const Snippet: React.FC<SnippetProps> = ({
children,
lines,
omitted_placeholder = "...",
strip_leading_spaces = false,
hash,
...props
}) => {
const [error, setError] = useState<string | null>(null);

if (typeof children !== "string") {
console.error(
throw new Error(
"Snippet expects children to be a string containing the file content."
);
return null;
}

// Parse the `linesToInclude` metadata string into an array of line numbers.
/**
* Utility function to compute the SHA-256 hash of a string.
* @param content The input string
* @returns Promise resolving to a hex-encoded hash
*/
const computeHash = async (content: string): Promise<string> => {
const encoder = new TextEncoder();
const data = encoder.encode(content);
const hashBuffer = await crypto.subtle.digest("SHA-256", data);
return Array.from(new Uint8Array(hashBuffer))
.map((byte) => byte.toString(16).padStart(2, "0"))
.join("");
};

useEffect(() => {
if (lines) {
computeHash(children).then((computedHash) => {
const shortHash = computedHash.slice(0, 7); // Use 7 characters for the short hash

if (!hash) {
setError(
`The \`hash\` prop is required when \`lines\` is specified.\n` +
`Provide the following hash as the \`hash\` prop: ${shortHash}`
);
} else if (shortHash !== hash) {
setError(
`Snippet hash mismatch.\n` +
`Specified: ${hash}, but content is: ${shortHash} (full hash: ${computedHash}).\n` +
`Check if the line numbers are still relevant and update the hash.`
);
}
});
}
}, [children, lines, hash]);

if (error) {
throw new Error(error);
}

// Parse the `lines` metadata string into an array of line numbers.
const parseLineRanges = (metaString?: string): number[] => {
if (!metaString) return [];
return metaString.split(",").flatMap((range) => {
Expand All @@ -46,16 +92,27 @@ const Snippet: React.FC<SnippetProps> = ({
if (lines.length === 0) return content; // If no specific lines are specified, return full content.

const includedContent: string[] = [];

// Filter lines and find the minimum indentation
const selectedLines = lines
.map((line) => allLines[line - 1] || "")
.filter((line) => line.trim().length > 0); // Ignore blank lines

const minIndent = selectedLines.reduce((min, line) => {
const indentMatch = line.match(/^(\s*)\S/);
const indentLength = indentMatch ? indentMatch[1].length : 0;
return Math.min(min, indentLength);
}, Infinity);

lines.forEach((line, index) => {
if (index > 0 && lines[index - 1] < line - 1) {
includedContent.push(omitted_placeholder); // Add placeholder for omitted lines
}

const rawLine = allLines[line - 1] || "";
const formattedLine = strip_leading_spaces
? rawLine.trimStart()
: rawLine;
includedContent.push(formattedLine);
const trimmedLine =
rawLine.trim().length > 0 ? rawLine.slice(minIndent) : rawLine;
includedContent.push(trimmedLine);
});

// Add placeholder if lines at the end are omitted
Expand Down