Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { listCommand } from './commands/list.js'
import { removeCommand } from './commands/remove.js'
import { rollbackCommand } from './commands/rollback.js'
import { repairCommand } from './commands/repair.js'
import { scanCommand } from './commands/scan.js'
import { setupCommand } from './commands/setup.js'

async function main(): Promise<void> {
Expand All @@ -19,6 +20,7 @@ async function main(): Promise<void> {
.command(rollbackCommand)
.command(removeCommand)
.command(listCommand)
.command(scanCommand)
.command(setupCommand)
.command(repairCommand)
.demandCommand(1, 'You must specify a command')
Expand Down
42 changes: 20 additions & 22 deletions src/commands/apply.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ import {
PatchManifestSchema,
DEFAULT_PATCH_MANIFEST_PATH,
} from '../schema/manifest-schema.js'
import {
findNodeModules,
findPackagesForPatches,
applyPackagePatch,
} from '../patch/apply.js'
import { applyPackagePatch } from '../patch/apply.js'
import type { ApplyResult } from '../patch/apply.js'
import {
cleanupUnusedBlobs,
Expand All @@ -20,7 +16,7 @@ import {
fetchMissingBlobs,
formatFetchResult,
} from '../utils/blob-fetcher.js'
import { getGlobalPrefix } from '../utils/global-packages.js'
import { NpmCrawler } from '../crawlers/index.js'
import {
trackPatchApplied,
trackPatchApplyFailed,
Expand Down Expand Up @@ -98,22 +94,23 @@ async function applyPatches(
}
}

// Find node_modules directories
// Find node_modules directories using the crawler
const crawler = new NpmCrawler()
let nodeModulesPaths: string[]
if (useGlobal || globalPrefix) {
try {
nodeModulesPaths = [getGlobalPrefix(globalPrefix)]
if (!silent) {
console.log(`Using global npm packages at: ${nodeModulesPaths[0]}`)
}
} catch (error) {
if (!silent) {
console.error('Failed to find global npm packages:', error instanceof Error ? error.message : String(error))
}
return { success: false, results: [] }
try {
nodeModulesPaths = await crawler.getNodeModulesPaths({
cwd,
global: useGlobal,
globalPrefix,
})
if ((useGlobal || globalPrefix) && !silent && nodeModulesPaths.length > 0) {
console.log(`Using global npm packages at: ${nodeModulesPaths[0]}`)
}
} else {
nodeModulesPaths = await findNodeModules(cwd)
} catch (error) {
if (!silent) {
console.error('Failed to find npm packages:', error instanceof Error ? error.message : String(error))
}
return { success: false, results: [] }
}

if (nodeModulesPaths.length === 0) {
Expand All @@ -123,10 +120,11 @@ async function applyPatches(
return { success: false, results: [] }
}

// Find all packages that need patching
// Find all packages that need patching using the crawler
const manifestPurls = Object.keys(manifest.patches)
const allPackages = new Map<string, string>()
for (const nmPath of nodeModulesPaths) {
const packages = await findPackagesForPatches(nmPath, manifest)
const packages = await crawler.findByPurls(nmPath, manifestPurls)
for (const [purl, location] of packages) {
if (!allPackages.has(purl)) {
allPackages.set(purl, location.path)
Expand Down
160 changes: 93 additions & 67 deletions src/commands/get.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,9 @@ import {
cleanupUnusedBlobs,
formatCleanupResult,
} from '../utils/cleanup-blobs.js'
import {
enumerateNodeModules,
type EnumeratedPackage,
} from '../utils/enumerate-packages.js'
import { NpmCrawler, type CrawledPackage } from '../crawlers/index.js'
import { fuzzyMatchPackages, isPurl } from '../utils/fuzzy-match.js'
import { getGlobalPrefix } from '../utils/global-packages.js'
import {
findNodeModules,
findPackagesForPatches,
applyPackagePatch,
} from '../patch/apply.js'
import { applyPackagePatch } from '../patch/apply.js'
import { rollbackPackagePatch } from '../patch/rollback.js'
import {
getMissingBlobs,
Expand All @@ -36,7 +28,7 @@ import {
/**
* Represents a package that has available patches with CVE information
*/
interface PackageWithPatchInfo extends EnumeratedPackage {
interface PackageWithPatchInfo extends CrawledPackage {
/** Available patches for this package */
patches: PatchSearchResult[]
/** Whether user can access paid patches */
Expand Down Expand Up @@ -129,7 +121,7 @@ async function findInstalledPurls(
async function findPackagesWithPatches(
apiClient: APIClient,
orgSlug: string | null,
packages: EnumeratedPackage[],
packages: CrawledPackage[],
onProgress?: (checked: number, total: number, current: string) => void,
): Promise<PackageWithPatchInfo[]> {
const packagesWithPatches: PackageWithPatchInfo[] = []
Expand All @@ -152,20 +144,16 @@ async function findPackagesWithPatches(

const { patches, canAccessPaidPatches } = searchResponse

// Filter to only accessible patches
const accessiblePatches = patches.filter(
patch => patch.tier === 'free' || canAccessPaidPatches,
)

if (accessiblePatches.length === 0) {
// Include all patches (free and paid) - we'll show upgrade CTA for paid patches
if (patches.length === 0) {
continue
}

// Extract CVE and GHSA IDs from patches
// Extract CVE and GHSA IDs from all patches
const cveIds = new Set<string>()
const ghsaIds = new Set<string>()

for (const patch of accessiblePatches) {
for (const patch of patches) {
for (const [vulnId, vulnInfo] of Object.entries(patch.vulnerabilities)) {
// Check if the vulnId itself is a GHSA
if (GHSA_PATTERN.test(vulnId)) {
Expand All @@ -185,7 +173,7 @@ async function findPackagesWithPatches(

packagesWithPatches.push({
...pkg,
patches: accessiblePatches,
patches,
canAccessPaidPatches,
cveIds: Array.from(cveIds).sort(),
ghsaIds: Array.from(ghsaIds).sort(),
Expand Down Expand Up @@ -272,6 +260,10 @@ async function promptSelectPackageWithPatches(
? `${vulnIds.slice(0, 3).join(', ')} (+${vulnIds.length - 3} more)`
: vulnIds.join(', ')

// Count free vs paid patches
const freePatches = pkg.patches.filter(p => p.tier === 'free').length
const paidPatches = pkg.patches.filter(p => p.tier === 'paid').length

// Count patches and show severity info
const severities = new Set<string>()
for (const patch of pkg.patches) {
Expand All @@ -284,8 +276,18 @@ async function promptSelectPackageWithPatches(
return order.indexOf(a.toLowerCase()) - order.indexOf(b.toLowerCase())
})

// Build patch count string
let patchCountStr = String(freePatches)
if (paidPatches > 0) {
if (pkg.canAccessPaidPatches) {
patchCountStr += `+${paidPatches}`
} else {
patchCountStr += `+\x1b[33m${paidPatches} paid\x1b[0m`
}
}

console.log(` ${i + 1}. ${displayName}@${pkg.version}`)
console.log(` Patches: ${pkg.patches.length} | Severity: ${severityList.join(', ')}`)
console.log(` Patches: ${patchCountStr} | Severity: ${severityList.join(', ')}`)
console.log(` Fixes: ${vulnSummary}`)
}

Expand Down Expand Up @@ -453,19 +455,20 @@ async function applyDownloadedPatches(
}
}

// Find node_modules directories
// Find node_modules directories using the crawler
const crawler = new NpmCrawler()
let nodeModulesPaths: string[]
if (useGlobal || globalPrefix) {
try {
nodeModulesPaths = [getGlobalPrefix(globalPrefix)]
} catch (error) {
if (!silent) {
console.error('Failed to find global npm packages:', error instanceof Error ? error.message : String(error))
}
return false
try {
nodeModulesPaths = await crawler.getNodeModulesPaths({
cwd,
global: useGlobal,
globalPrefix,
})
} catch (error) {
if (!silent) {
console.error('Failed to find npm packages:', error instanceof Error ? error.message : String(error))
}
} else {
nodeModulesPaths = await findNodeModules(cwd)
return false
}

if (nodeModulesPaths.length === 0) {
Expand All @@ -475,10 +478,11 @@ async function applyDownloadedPatches(
return false
}

// Find all packages that need patching
// Find all packages that need patching using the crawler
const manifestPurls = Object.keys(manifest.patches)
const allPackages = new Map<string, string>()
for (const nmPath of nodeModulesPaths) {
const packages = await findPackagesForPatches(nmPath, manifest)
const packages = await crawler.findByPurls(nmPath, manifestPurls)
for (const [purl, location] of packages) {
if (!allPackages.has(purl)) {
allPackages.set(purl, location.path)
Expand Down Expand Up @@ -551,19 +555,21 @@ async function applyOneOffPatch(
silent: boolean,
globalPrefix?: string,
): Promise<{ success: boolean; rollback?: () => Promise<void> }> {
// Find the package location
// Find the package location using the crawler
const crawler = new NpmCrawler()
let nodeModulesPath: string
if (useGlobal || globalPrefix) {
try {
nodeModulesPath = getGlobalPrefix(globalPrefix)
} catch (error) {
if (!silent) {
console.error('Failed to find global npm packages:', error instanceof Error ? error.message : String(error))
}
return { success: false }
try {
const paths = await crawler.getNodeModulesPaths({
cwd,
global: useGlobal,
globalPrefix,
})
nodeModulesPath = paths[0] ?? path.join(cwd, 'node_modules')
} catch (error) {
if (!silent) {
console.error('Failed to find npm packages:', error instanceof Error ? error.message : String(error))
}
} else {
nodeModulesPath = path.join(cwd, 'node_modules')
return { success: false }
}

// Parse PURL to get package directory
Expand Down Expand Up @@ -718,19 +724,23 @@ async function getPatches(args: GetArgs): Promise<boolean> {
// The org slug to use (null when using public proxy)
const effectiveOrgSlug = usePublicProxy ? null : orgSlug ?? null

// Determine node_modules path for package lookups
// Determine node_modules path for package lookups using the crawler
const crawler = new NpmCrawler()
let nodeModulesPath: string
if (useGlobal || globalPrefix) {
try {
nodeModulesPath = getGlobalPrefix(globalPrefix)
try {
const paths = await crawler.getNodeModulesPaths({
cwd,
global: useGlobal,
globalPrefix,
})
nodeModulesPath = paths[0] ?? path.join(cwd, 'node_modules')
if (useGlobal || globalPrefix) {
console.log(`Using global npm packages at: ${nodeModulesPath}`)
} catch (error) {
throw new Error(
`Failed to find global npm packages: ${error instanceof Error ? error.message : String(error)}`,
)
}
} else {
nodeModulesPath = path.join(cwd, 'node_modules')
} catch (error) {
throw new Error(
`Failed to find npm packages: ${error instanceof Error ? error.message : String(error)}`,
)
}

// Determine identifier type
Expand Down Expand Up @@ -765,6 +775,18 @@ async function getPatches(args: GetArgs): Promise<boolean> {
return true
}

// Check if patch is paid and user doesn't have access
if (patch.tier === 'paid' && usePublicProxy) {
console.log(`\n\x1b[33m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m`)
console.log(`\x1b[33m This patch requires a paid subscription to download.\x1b[0m`)
console.log(`\x1b[33m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m`)
console.log(`\n Patch: ${patch.purl}`)
console.log(` Tier: \x1b[33mpaid\x1b[0m`)
console.log(`\n Upgrade to Socket's paid plan to access this patch and many more:`)
console.log(` \x1b[36mhttps://socket.dev/pricing\x1b[0m\n`)
return true
}

// Handle one-off mode
if (oneOff) {
const { success, rollback } = await applyOneOffPatch(patch, useGlobal ?? false, cwd, false, globalPrefix)
Expand Down Expand Up @@ -838,12 +860,13 @@ async function getPatches(args: GetArgs): Promise<boolean> {
break
}
case 'package': {
// Enumerate packages from node_modules and fuzzy match
const enumPath = useGlobal ? nodeModulesPath : cwd
console.log(`Enumerating packages in ${enumPath}...`)
const packages = useGlobal
? await enumerateNodeModules(path.dirname(nodeModulesPath))
: await enumerateNodeModules(cwd)
// Enumerate packages from node_modules using the crawler and fuzzy match
console.log(`Enumerating packages in ${nodeModulesPath}...`)
const packages = await crawler.crawlAll({
cwd,
global: useGlobal,
globalPrefix,
})

if (packages.length === 0) {
console.log(useGlobal
Expand Down Expand Up @@ -997,16 +1020,19 @@ async function getPatches(args: GetArgs): Promise<boolean> {
console.log(`Note: ${notInstalledCount} patch(es) for packages not installed in this project were hidden.`)
}

if (inaccessibleCount > 0) {
if (inaccessibleCount > 0 && !canAccessPaidPatches) {
console.log(
`Note: ${inaccessibleCount} patch(es) require paid access and will be skipped.`,
`\x1b[33mNote: ${inaccessibleCount} patch(es) require a paid subscription and will be skipped.\x1b[0m`,
)
}

if (accessiblePatches.length === 0) {
console.log(
'No accessible patches available. Upgrade to access paid patches.',
)
console.log(`\n\x1b[33m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m`)
console.log(`\x1b[33m All available patches require a paid subscription.\x1b[0m`)
console.log(`\x1b[33m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\x1b[0m`)
console.log(`\n Found ${inaccessibleCount} paid patch(es) that you cannot currently access.`)
console.log(`\n Upgrade to Socket's paid plan to access these patches:`)
console.log(` \x1b[36mhttps://socket.dev/pricing\x1b[0m\n`)
return true
}

Expand Down
Loading