diff --git a/cmd/files_test.go b/cmd/files_test.go
new file mode 100644
index 0000000..3cf875e
--- /dev/null
+++ b/cmd/files_test.go
@@ -0,0 +1,147 @@
+package cmd
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/tnypxl/rollup/internal/config"
+)
+
+func TestMatchGlob(t *testing.T) {
+ tests := []struct {
+ pattern string
+ path string
+ expected bool
+ }{
+ {"*.go", "file.go", true},
+ {"*.go", "file.txt", false},
+ {"**/*.go", "dir/file.go", true},
+ {"**/*.go", "dir/subdir/file.go", true},
+ {"dir/*.go", "dir/file.go", true},
+ {"dir/*.go", "otherdir/file.go", false},
+ }
+
+ for _, test := range tests {
+ result := matchGlob(test.pattern, test.path)
+ if result != test.expected {
+ t.Errorf("matchGlob(%q, %q) = %v; want %v", test.pattern, test.path, result, test.expected)
+ }
+ }
+}
+
+func TestIsCodeGenerated(t *testing.T) {
+ patterns := []string{"generated_*.go", "**/auto_*.go"}
+ tests := []struct {
+ path string
+ expected bool
+ }{
+ {"generated_file.go", true},
+ {"normal_file.go", false},
+ {"subdir/auto_file.go", true},
+ {"subdir/normal_file.go", false},
+ }
+
+ for _, test := range tests {
+ result := isCodeGenerated(test.path, patterns)
+ if result != test.expected {
+ t.Errorf("isCodeGenerated(%q, %v) = %v; want %v", test.path, patterns, result, test.expected)
+ }
+ }
+}
+
+func TestIsIgnored(t *testing.T) {
+ patterns := []string{"*.tmp", "**/*.log"}
+ tests := []struct {
+ path string
+ expected bool
+ }{
+ {"file.tmp", true},
+ {"file.go", false},
+ {"subdir/file.log", true},
+ {"subdir/file.txt", false},
+ }
+
+ for _, test := range tests {
+ result := isIgnored(test.path, patterns)
+ if result != test.expected {
+ t.Errorf("isIgnored(%q, %v) = %v; want %v", test.path, patterns, result, test.expected)
+ }
+ }
+}
+
+func TestRunRollup(t *testing.T) {
+ // Create a temporary directory for testing
+ tempDir, err := os.MkdirTemp("", "rollup_test")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tempDir)
+
+ // Create some test files
+ files := map[string]string{
+ "file1.go": "package main\n\nfunc main() {}\n",
+ "file2.txt": "This is a text file.\n",
+ "subdir/file3.go": "package subdir\n\nfunc Func() {}\n",
+ "subdir/file4.json": "{\"key\": \"value\"}\n",
+ }
+
+ for name, content := range files {
+ path := filepath.Join(tempDir, name)
+ err := os.MkdirAll(filepath.Dir(path), 0755)
+ if err != nil {
+ t.Fatalf("Failed to create directory: %v", err)
+ }
+ err = os.WriteFile(path, []byte(content), 0644)
+ if err != nil {
+ t.Fatalf("Failed to write file: %v", err)
+ }
+ }
+
+ // Set up the package-level cfg and path globals that runRollup reads
+ cfg = &config.Config{
+ FileTypes: []string{"go", "txt"},
+ Ignore: []string{"*.json"},
+ CodeGenerated: []string{},
+ }
+ path = tempDir
+
+ // Run the rollup
+ err = runRollup()
+ if err != nil {
+ t.Fatalf("runRollup() failed: %v", err)
+ }
+
+ // Check if the output file was created
+ outputFiles, err := filepath.Glob(filepath.Join(tempDir, "*.rollup.md"))
+ if err != nil {
+ t.Fatalf("Failed to glob output files: %v", err)
+ }
+ if len(outputFiles) != 1 {
+ t.Fatalf("Expected 1 output file, got %d", len(outputFiles))
+ }
+
+ // Read the content of the output file
+ content, err := os.ReadFile(outputFiles[0])
+ if err != nil {
+ t.Fatalf("Failed to read output file: %v", err)
+ }
+
+ // Check if the content includes the expected files
+ expectedContent := []string{
+ "# File: file1.go",
+ "# File: file2.txt",
+ "# File: subdir/file3.go",
+ }
+ for _, expected := range expectedContent {
+ if !strings.Contains(string(content), expected) {
+ t.Errorf("Output file does not contain expected content: %s", expected)
+ }
+ }
+
+ // Check if the ignored file is not included
+ if strings.Contains(string(content), "file4.json") {
+ t.Errorf("Output file contains ignored file: file4.json")
+ }
+}
diff --git a/cmd/web.go b/cmd/web.go
index 26ac38c..f35418c 100644
--- a/cmd/web.go
+++ b/cmd/web.go
@@ -2,6 +2,8 @@ package cmd
import (
"fmt"
+ "io/ioutil"
+ "log"
"net/url"
"os"
"regexp"
@@ -9,6 +11,7 @@ import (
"time"
"github.com/spf13/cobra"
+ "github.com/tnypxl/rollup/internal/config"
"github.com/tnypxl/rollup/internal/scraper"
)
@@ -38,47 +41,93 @@ func init() {
}
func runWeb(cmd *cobra.Command, args []string) error {
- scraperConfig.Verbose = verbose
+ scraper.SetupLogger(verbose)
+ logger := log.New(os.Stdout, "WEB: ", log.LstdFlags)
+ if !verbose {
+ logger.SetOutput(ioutil.Discard)
+ }
+ logger.Printf("Starting web scraping process with verbose mode: %v", verbose)
+ scraperConfig.Verbose = verbose
- // Use config if available, otherwise use command-line flags
- var urlConfigs []scraper.URLConfig
- if len(urls) == 0 && len(cfg.Scrape.URLs) > 0 {
- urlConfigs = make([]scraper.URLConfig, len(cfg.Scrape.URLs))
- for i, u := range cfg.Scrape.URLs {
- urlConfigs[i] = scraper.URLConfig{
- URL: u.URL,
- CSSLocator: u.CSSLocator,
- ExcludeSelectors: u.ExcludeSelectors,
- OutputAlias: u.OutputAlias,
- }
- }
- } else {
- urlConfigs = make([]scraper.URLConfig, len(urls))
- for i, u := range urls {
- urlConfigs[i] = scraper.URLConfig{URL: u, CSSLocator: includeSelector}
- }
- }
+ var siteConfigs []scraper.SiteConfig
+ if len(cfg.Scrape.Sites) > 0 {
+ logger.Printf("Using configuration from rollup.yml for %d sites", len(cfg.Scrape.Sites))
+ siteConfigs = make([]scraper.SiteConfig, len(cfg.Scrape.Sites))
+ for i, site := range cfg.Scrape.Sites {
+ siteConfigs[i] = scraper.SiteConfig{
+ BaseURL: site.BaseURL,
+ CSSLocator: site.CSSLocator,
+ ExcludeSelectors: site.ExcludeSelectors,
+ MaxDepth: site.MaxDepth,
+ AllowedPaths: site.AllowedPaths,
+ ExcludePaths: site.ExcludePaths,
+ OutputAlias: site.OutputAlias,
+ PathOverrides: convertPathOverrides(site.PathOverrides),
+ }
+ logger.Printf("Site %d configuration: BaseURL=%s, CSSLocator=%s, MaxDepth=%d, AllowedPaths=%v",
+ i+1, site.BaseURL, site.CSSLocator, site.MaxDepth, site.AllowedPaths)
+ }
+ } else {
+ logger.Printf("No sites defined in rollup.yml, falling back to URL-based configuration")
+ siteConfigs = make([]scraper.SiteConfig, len(urls))
+ for i, u := range urls {
+ siteConfigs[i] = scraper.SiteConfig{
+ BaseURL: u,
+ CSSLocator: includeSelector,
+ ExcludeSelectors: excludeSelectors,
+ MaxDepth: depth,
+ }
+ logger.Printf("URL %d configuration: BaseURL=%s, CSSLocator=%s, MaxDepth=%d",
+ i+1, u, includeSelector, depth)
+ }
+ }
- if len(urlConfigs) == 0 {
- return fmt.Errorf("no URLs provided. Use --urls flag with comma-separated URLs or set 'scrape.urls' in the rollup.yml file")
- }
+ if len(siteConfigs) == 0 {
+ logger.Println("Error: No sites or URLs provided")
+ return fmt.Errorf("no sites or URLs provided. Use --urls flag with comma-separated URLs or set 'scrape.sites' in the rollup.yml file")
+ }
- scraperConfig := scraper.Config{
- URLs: urlConfigs,
- OutputType: outputType,
- Verbose: verbose,
- }
+ // Set default values for rate limiting
+ defaultRequestsPerSecond := 1.0
+ defaultBurstLimit := 3
- scrapedContent, err := scraper.ScrapeMultipleURLs(scraperConfig)
- if err != nil {
- return fmt.Errorf("error scraping content: %v", err)
- }
+ // Use default values if not set in the configuration
+ requestsPerSecond := cfg.Scrape.RequestsPerSecond
+ if requestsPerSecond == 0 {
+ requestsPerSecond = defaultRequestsPerSecond
+ }
+ burstLimit := cfg.Scrape.BurstLimit
+ if burstLimit == 0 {
+ burstLimit = defaultBurstLimit
+ }
- if outputType == "single" {
- return writeSingleFile(scrapedContent)
- } else {
- return writeMultipleFiles(scrapedContent)
- }
+ scraperConfig := scraper.Config{
+ Sites: siteConfigs,
+ OutputType: outputType,
+ Verbose: verbose,
+ Scrape: scraper.ScrapeConfig{
+ RequestsPerSecond: requestsPerSecond,
+ BurstLimit: burstLimit,
+ },
+ }
+ logger.Printf("Scraper configuration: OutputType=%s, RequestsPerSecond=%f, BurstLimit=%d",
+ outputType, requestsPerSecond, burstLimit)
+
+ logger.Println("Starting scraping process")
+ scrapedContent, err := scraper.ScrapeSites(scraperConfig)
+ if err != nil {
+ logger.Printf("Error occurred during scraping: %v", err)
+ return fmt.Errorf("error scraping content: %v", err)
+ }
+ logger.Printf("Scraping completed. Total content scraped: %d", len(scrapedContent))
+
+ if outputType == "single" {
+ logger.Println("Writing content to a single file")
+ return writeSingleFile(scrapedContent)
+ } else {
+ logger.Println("Writing content to multiple files")
+ return writeMultipleFiles(scrapedContent)
+ }
}
func writeSingleFile(content map[string]string) error {
@@ -102,20 +151,26 @@ func writeSingleFile(content map[string]string) error {
func writeMultipleFiles(content map[string]string) error {
for url, c := range content {
- filename := getFilenameFromContent(c, url)
+ filename, err := getFilenameFromContent(c, url)
+ if err != nil {
+ return fmt.Errorf("error generating filename for %s: %v", url, err)
+ }
+
file, err := os.Create(filename)
if err != nil {
return fmt.Errorf("error creating output file %s: %v", filename, err)
}
- _, err = fmt.Fprintf(file, "# Content from %s\n\n%s", url, c)
- file.Close()
+ _, err = file.WriteString(fmt.Sprintf("# Content from %s\n\n%s\n", url, c))
if err != nil {
+ file.Close()
return fmt.Errorf("error writing content to file %s: %v", filename, err)
}
+ file.Close()
fmt.Printf("Content from %s has been saved to %s\n", url, filename)
}
+
return nil
}
@@ -136,13 +191,13 @@ func scrapeURL(urlStr string, depth int, visited map[string]bool) (string, error
visited[urlStr] = true
- content, err := extractAndConvertContent(urlStr)
+ content, err := testExtractAndConvertContent(urlStr)
if err != nil {
return "", err
}
if depth > 0 {
- links, err := scraper.ExtractLinks(urlStr)
+ links, err := testExtractLinks(urlStr)
if err != nil {
return content, fmt.Errorf("error extracting links: %v", err)
}
@@ -160,6 +215,9 @@ func scrapeURL(urlStr string, depth int, visited map[string]bool) (string, error
return content, nil
}
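+// Package-level function variables used by scrapeURL so tests can substitute
+// fakes for the network-backed helpers.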
+var testExtractAndConvertContent = extractAndConvertContent
+var testExtractLinks = scraper.ExtractLinks
+
func extractAndConvertContent(urlStr string) (string, error) {
content, err := scraper.FetchWebpageContent(urlStr)
if err != nil {
@@ -187,17 +245,32 @@ func extractAndConvertContent(urlStr string) (string, error) {
return header + markdown + "\n\n", nil
}
-func getFilenameFromContent(content, url string) string {
+func getFilenameFromContent(content, urlStr string) (string, error) {
// Try to extract title from content
 titleStart := strings.Index(content, "<title>")
 titleEnd := strings.Index(content, "</title>")
if titleStart != -1 && titleEnd != -1 && titleEnd > titleStart {
- title := content[titleStart+7 : titleEnd]
- return sanitizeFilename(title) + ".md"
+ title := strings.TrimSpace(content[titleStart+7 : titleEnd])
+ if title != "" {
+ return sanitizeFilename(title) + ".rollup.md", nil
+ }
}
- // If no title found, use the URL
- return sanitizeFilename(url) + ".md"
+ // If no title found or title is empty, use the URL
+ parsedURL, err := url.Parse(urlStr)
+ if err != nil {
+ return "", fmt.Errorf("invalid URL: %v", err)
+ }
+
+ if parsedURL.Host == "" {
+ return "", fmt.Errorf("invalid URL: missing host")
+ }
+
+ filename := parsedURL.Host
+ if parsedURL.Path != "" && parsedURL.Path != "/" {
+ filename += strings.TrimSuffix(parsedURL.Path, "/")
+ }
+ return sanitizeFilename(filename) + ".rollup.md", nil
}
func sanitizeFilename(name string) string {
@@ -215,3 +288,15 @@ func sanitizeFilename(name string) string {
return name
}
+
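+// convertPathOverrides maps config.PathOverride values onto the scraper
+// package's equivalent PathOverride type.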
+func convertPathOverrides(configOverrides []config.PathOverride) []scraper.PathOverride {
+ scraperOverrides := make([]scraper.PathOverride, len(configOverrides))
+ for i, override := range configOverrides {
+ scraperOverrides[i] = scraper.PathOverride{
+ Path: override.Path,
+ CSSLocator: override.CSSLocator,
+ ExcludeSelectors: override.ExcludeSelectors,
+ }
+ }
+ return scraperOverrides
+}
diff --git a/cmd/web_test.go b/cmd/web_test.go
new file mode 100644
index 0000000..8e470be
--- /dev/null
+++ b/cmd/web_test.go
@@ -0,0 +1,154 @@
+package cmd
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/tnypxl/rollup/internal/config"
+)
+
+func TestConvertPathOverrides(t *testing.T) {
+ configOverrides := []config.PathOverride{
+ {
+ Path: "/blog",
+ CSSLocator: "article",
+ ExcludeSelectors: []string{".ads", ".comments"},
+ },
+ {
+ Path: "/products",
+ CSSLocator: ".product-description",
+ ExcludeSelectors: []string{".related-items"},
+ },
+ }
+
+ scraperOverrides := convertPathOverrides(configOverrides)
+
+ if len(scraperOverrides) != len(configOverrides) {
+ t.Errorf("Expected %d overrides, got %d", len(configOverrides), len(scraperOverrides))
+ }
+
+ for i, override := range scraperOverrides {
+ if override.Path != configOverrides[i].Path {
+ t.Errorf("Expected Path %s, got %s", configOverrides[i].Path, override.Path)
+ }
+ if override.CSSLocator != configOverrides[i].CSSLocator {
+ t.Errorf("Expected CSSLocator %s, got %s", configOverrides[i].CSSLocator, override.CSSLocator)
+ }
+ if len(override.ExcludeSelectors) != len(configOverrides[i].ExcludeSelectors) {
+ t.Errorf("Expected %d ExcludeSelectors, got %d", len(configOverrides[i].ExcludeSelectors), len(override.ExcludeSelectors))
+ }
+ for j, selector := range override.ExcludeSelectors {
+ if selector != configOverrides[i].ExcludeSelectors[j] {
+ t.Errorf("Expected ExcludeSelector %s, got %s", configOverrides[i].ExcludeSelectors[j], selector)
+ }
+ }
+ }
+}
+
+func TestSanitizeFilename(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"Hello, World!", "Hello_World"},
+ {"file/with/path", "file_with_path"},
+ {"file.with.dots", "file_with_dots"},
+ {"___leading_underscores___", "leading_underscores"},
+ {"", "untitled"},
+ {"!@#$%^&*()", "untitled"},
+ }
+
+ for _, test := range tests {
+ result := sanitizeFilename(test.input)
+ if result != test.expected {
+ t.Errorf("sanitizeFilename(%q) = %q; want %q", test.input, result, test.expected)
+ }
+ }
+}
+
+func TestGetFilenameFromContent(t *testing.T) {
+ tests := []struct {
+ content string
+ url string
+ expected string
+ expectErr bool
+ }{
+ {"Test Page", "http://example.com", "Test_Page.rollup.md", false},
+ {"No title here", "http://example.com/page", "example_com_page.rollup.md", false},
+ {" Trim Me ", "http://example.com", "Trim_Me.rollup.md", false},
+ {"", "http://example.com", "example_com.rollup.md", false},
+ {" ", "http://example.com", "example_com.rollup.md", false},
+ {"Invalid URL", "not a valid url", "", true},
+ {"No host", "http://", "", true},
+ }
+
+ for _, test := range tests {
+ result, err := getFilenameFromContent(test.content, test.url)
+ if test.expectErr {
+ if err == nil {
+ t.Errorf("getFilenameFromContent(%q, %q) expected an error, but got none", test.content, test.url)
+ }
+ } else {
+ if err != nil {
+ t.Errorf("getFilenameFromContent(%q, %q) unexpected error: %v", test.content, test.url, err)
+ }
+ if result != test.expected {
+ t.Errorf("getFilenameFromContent(%q, %q) = %q; want %q", test.content, test.url, result, test.expected)
+ }
+ }
+ }
+}
+
+// Mock functions for testing
+func mockExtractAndConvertContent(urlStr string) (string, error) {
+ return "Mocked content for " + urlStr, nil
+}
+
+func mockExtractLinks(urlStr string) ([]string, error) {
+ return []string{"http://example.com/link1", "http://example.com/link2"}, nil
+}
+
+func TestScrapeURL(t *testing.T) {
+ // Store the original functions
+ originalExtractAndConvertContent := testExtractAndConvertContent
+ originalExtractLinks := testExtractLinks
+
+ // Define mock functions
+ testExtractAndConvertContent = func(urlStr string) (string, error) {
+ return "Mocked content for " + urlStr, nil
+ }
+ testExtractLinks = func(urlStr string) ([]string, error) {
+ return []string{"http://example.com/link1", "http://example.com/link2"}, nil
+ }
+
+ // Defer the restoration of original functions
+ defer func() {
+ testExtractAndConvertContent = originalExtractAndConvertContent
+ testExtractLinks = originalExtractLinks
+ }()
+
+ tests := []struct {
+ url string
+ depth int
+ expectedCalls int
+ }{
+ {"http://example.com", 0, 1},
+ {"http://example.com", 1, 3},
+ {"http://example.com", 2, 3}, // Same as depth 1 because our mock only returns 2 links
+ }
+
+ for _, test := range tests {
+ visited := make(map[string]bool)
+ content, err := scrapeURL(test.url, test.depth, visited)
+ if err != nil {
+ t.Errorf("scrapeURL(%q, %d) returned error: %v", test.url, test.depth, err)
+ continue
+ }
+ if len(visited) != test.expectedCalls {
+ t.Errorf("scrapeURL(%q, %d) made %d calls, expected %d", test.url, test.depth, len(visited), test.expectedCalls)
+ }
+ expectedContent := "Mocked content for " + test.url
+ if !strings.Contains(content, expectedContent) {
+ t.Errorf("scrapeURL(%q, %d) content doesn't contain %q", test.url, test.depth, expectedContent)
+ }
+ }
+}
diff --git a/go.mod b/go.mod
index 47b03f2..6946697 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.23
require (
github.com/JohannesKaufmann/html-to-markdown v1.6.0
github.com/spf13/cobra v1.8.1
+ golang.org/x/time v0.6.0
)
require (
diff --git a/go.sum b/go.sum
index 6903da8..a6b8c57 100644
--- a/go.sum
+++ b/go.sum
@@ -102,6 +102,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
+golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
diff --git a/internal/config/config.go b/internal/config/config.go
index 6ffd454..0042396 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -15,15 +15,27 @@ type Config struct {
}
type ScrapeConfig struct {
- URLs []URLConfig `yaml:"urls"`
- OutputType string `yaml:"output_type"`
+ Sites []SiteConfig `yaml:"sites"`
+ OutputType string `yaml:"output_type"`
+ RequestsPerSecond float64 `yaml:"requests_per_second"`
+ BurstLimit int `yaml:"burst_limit"`
}
-type URLConfig struct {
- URL string `yaml:"url"`
- CSSLocator string `yaml:"css_locator"`
- ExcludeSelectors []string `yaml:"exclude_selectors"`
- OutputAlias string `yaml:"output_alias"`
+type SiteConfig struct {
+ BaseURL string `yaml:"base_url"`
+ CSSLocator string `yaml:"css_locator"`
+ ExcludeSelectors []string `yaml:"exclude_selectors"`
+ MaxDepth int `yaml:"max_depth"`
+ AllowedPaths []string `yaml:"allowed_paths"`
+ ExcludePaths []string `yaml:"exclude_paths"`
+ OutputAlias string `yaml:"output_alias"`
+ PathOverrides []PathOverride `yaml:"path_overrides"`
+}
+
+type PathOverride struct {
+ Path string `yaml:"path"`
+ CSSLocator string `yaml:"css_locator"`
+ ExcludeSelectors []string `yaml:"exclude_selectors"`
}
func Load(configPath string) (*Config, error) {
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 0000000..a05c23f
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,120 @@
+package config
+
+import (
+ "os"
+ "reflect"
+ "testing"
+)
+
+func TestLoad(t *testing.T) {
+ // Create a temporary config file
+ content := []byte(`
+file_types:
+ - go
+ - md
+ignore:
+ - "*.tmp"
+ - "**/*.log"
+code_generated:
+ - "generated_*.go"
+scrape:
+ sites:
+ - base_url: "https://example.com"
+ css_locator: "main"
+ exclude_selectors:
+ - ".ads"
+ max_depth: 2
+ allowed_paths:
+ - "/blog"
+ exclude_paths:
+ - "/admin"
+ output_alias: "example"
+ path_overrides:
+ - path: "/special"
+ css_locator: ".special-content"
+ exclude_selectors:
+ - ".sidebar"
+ output_type: "single"
+ requests_per_second: 1.0
+ burst_limit: 5
+`)
+
+ tmpfile, err := os.CreateTemp("", "config*.yml")
+ if err != nil {
+ t.Fatalf("Failed to create temp file: %v", err)
+ }
+ defer os.Remove(tmpfile.Name())
+
+ if _, err := tmpfile.Write(content); err != nil {
+ t.Fatalf("Failed to write to temp file: %v", err)
+ }
+ if err := tmpfile.Close(); err != nil {
+ t.Fatalf("Failed to close temp file: %v", err)
+ }
+
+ // Test loading the config
+ config, err := Load(tmpfile.Name())
+ if err != nil {
+ t.Fatalf("Load() failed: %v", err)
+ }
+
+ // Check if the loaded config matches the expected values
+ expectedConfig := &Config{
+ FileTypes: []string{"go", "md"},
+ Ignore: []string{"*.tmp", "**/*.log"},
+ CodeGenerated: []string{"generated_*.go"},
+ Scrape: ScrapeConfig{
+ Sites: []SiteConfig{
+ {
+ BaseURL: "https://example.com",
+ CSSLocator: "main",
+ ExcludeSelectors: []string{".ads"},
+ MaxDepth: 2,
+ AllowedPaths: []string{"/blog"},
+ ExcludePaths: []string{"/admin"},
+ OutputAlias: "example",
+ PathOverrides: []PathOverride{
+ {
+ Path: "/special",
+ CSSLocator: ".special-content",
+ ExcludeSelectors: []string{".sidebar"},
+ },
+ },
+ },
+ },
+ OutputType: "single",
+ RequestsPerSecond: 1.0,
+ BurstLimit: 5,
+ },
+ }
+
+ if !reflect.DeepEqual(config, expectedConfig) {
+ t.Errorf("Loaded config does not match expected config.\nGot: %+v\nWant: %+v", config, expectedConfig)
+ }
+}
+
+func TestDefaultConfigPath(t *testing.T) {
+ expected := "rollup.yml"
+ result := DefaultConfigPath()
+ if result != expected {
+ t.Errorf("DefaultConfigPath() = %q, want %q", result, expected)
+ }
+}
+
+func TestFileExists(t *testing.T) {
+ // Test with an existing file
+ tmpfile, err := os.CreateTemp("", "testfile")
+ if err != nil {
+ t.Fatalf("Failed to create temp file: %v", err)
+ }
+ defer os.Remove(tmpfile.Name())
+
+ if !FileExists(tmpfile.Name()) {
+ t.Errorf("FileExists(%q) = false, want true", tmpfile.Name())
+ }
+
+ // Test with a non-existing file
+ if FileExists("non_existing_file.txt") {
+ t.Errorf("FileExists(\"non_existing_file.txt\") = true, want false")
+ }
+}
diff --git a/internal/scraper/scraper.go b/internal/scraper/scraper.go
index f963049..91b1399 100644
--- a/internal/scraper/scraper.go
+++ b/internal/scraper/scraper.go
@@ -5,13 +5,18 @@ import (
"io/ioutil"
"log"
"math/rand"
+ "net/url"
+ "os"
"regexp"
"strings"
"time"
+ "sync"
+ "context"
"github.com/PuerkitoBio/goquery"
"github.com/playwright-community/playwright-go"
md "github.com/JohannesKaufmann/html-to-markdown"
+ "golang.org/x/time/rate"
)
var logger *log.Logger
@@ -23,57 +28,240 @@ var (
// Config holds the scraper configuration
type Config struct {
- URLs []URLConfig
+ Sites []SiteConfig
OutputType string
Verbose bool
+ Scrape ScrapeConfig
}
-// ScrapeMultipleURLs scrapes multiple URLs concurrently
-func ScrapeMultipleURLs(config Config) (map[string]string, error) {
- results := make(chan struct {
- url string
- content string
- err error
- }, len(config.URLs))
-
- for _, urlConfig := range config.URLs {
- go func(cfg URLConfig) {
- content, err := scrapeURL(cfg)
- results <- struct {
- url string
- content string
- err error
- }{cfg.URL, content, err}
- }(urlConfig)
- }
-
- scrapedContent := make(map[string]string)
- for i := 0; i < len(config.URLs); i++ {
- result := <-results
- if result.err != nil {
- logger.Printf("Error scraping %s: %v\n", result.url, result.err)
- continue
- }
- scrapedContent[result.url] = result.content
- }
-
- return scrapedContent, nil
+// ScrapeConfig holds the scraping-specific configuration
+type ScrapeConfig struct {
+ RequestsPerSecond float64
+ BurstLimit int
}
-func scrapeURL(config URLConfig) (string, error) {
- content, err := FetchWebpageContent(config.URL)
- if err != nil {
- return "", err
- }
+// SiteConfig holds configuration for a single site
+type SiteConfig struct {
+ BaseURL string
+ CSSLocator string
+ ExcludeSelectors []string
+ MaxDepth int
+ AllowedPaths []string
+ ExcludePaths []string
+ OutputAlias string
+ PathOverrides []PathOverride
+}
- if config.CSSLocator != "" {
- content, err = ExtractContentWithCSS(content, config.CSSLocator, config.ExcludeSelectors)
- if err != nil {
- return "", err
- }
- }
+// PathOverride holds path-specific overrides
+type PathOverride struct {
+ Path string
+ CSSLocator string
+ ExcludeSelectors []string
+}
- return ProcessHTMLContent(content, Config{})
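+// ScrapeSites scrapes every allowed path of each configured site concurrently,
+// sharing a single rate limiter, and returns a map of URL to scraped content.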
+func ScrapeSites(config Config) (map[string]string, error) {
+ logger.Println("Starting ScrapeSites function - Verbose mode is active")
+ results := make(chan struct {
+ url string
+ content string
+ err error
+ })
+
+ limiter := rate.NewLimiter(rate.Limit(config.Scrape.RequestsPerSecond), config.Scrape.BurstLimit)
+ logger.Printf("Rate limiter configured with %f requests per second and burst limit of %d\n", config.Scrape.RequestsPerSecond, config.Scrape.BurstLimit)
+
+ var wg sync.WaitGroup
+ totalURLs := 0
+ for _, site := range config.Sites {
+ logger.Printf("Processing site: %s\n", site.BaseURL)
+ // Count this site's URLs before launching the goroutine so totalURLs is not
+ // written concurrently by multiple site goroutines (data race).
+ totalURLs += len(site.AllowedPaths)
+ wg.Add(1)
+ go func(site SiteConfig) {
+ defer wg.Done()
+ for _, path := range site.AllowedPaths {
+ fullURL := site.BaseURL + path
+ logger.Printf("Queueing URL for scraping: %s\n", fullURL)
+ scrapeSingleURL(fullURL, site, config, results, limiter)
+ }
+ }(site)
+ }
+
+ go func() {
+ wg.Wait()
+ close(results)
+ logger.Println("All goroutines completed, results channel closed")
+ }()
+
+ scrapedContent := make(map[string]string)
+ for result := range results {
+ if result.err != nil {
+ logger.Printf("Error scraping %s: %v\n", result.url, result.err)
+ continue
+ }
+ logger.Printf("Successfully scraped content from %s (length: %d)\n", result.url, len(result.content))
+ scrapedContent[result.url] = result.content
+ }
+
+ logger.Printf("Total URLs processed: %d\n", totalURLs)
+ logger.Printf("Successfully scraped content from %d URLs\n", len(scrapedContent))
+
+ return scrapedContent, nil
+}
+
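+// scrapeSingleURL waits on the rate limiter, resolves any path overrides for the
+// URL, scrapes it, and sends the result (or error) on the results channel.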
+func scrapeSingleURL(url string, site SiteConfig, config Config, results chan<- struct {
+ url string
+ content string
+ err error
+}, limiter *rate.Limiter) {
+ logger.Printf("Starting to scrape URL: %s\n", url)
+
+ // Wait for rate limiter before making the request
+ err := limiter.Wait(context.Background())
+ if err != nil {
+ logger.Printf("Rate limiter error for %s: %v\n", url, err)
+ results <- struct {
+ url string
+ content string
+ err error
+ }{url, "", fmt.Errorf("rate limiter error: %v", err)}
+ return
+ }
+
+ cssLocator, excludeSelectors := getOverrides(url, site)
+ logger.Printf("Using CSS locator for %s: %s\n", url, cssLocator)
+ logger.Printf("Exclude selectors for %s: %v\n", url, excludeSelectors)
+
+ content, err := scrapeURL(url, cssLocator, excludeSelectors)
+ if err != nil {
+ logger.Printf("Error scraping %s: %v\n", url, err)
+ results <- struct {
+ url string
+ content string
+ err error
+ }{url, "", err}
+ return
+ }
+
+ if content == "" {
+ logger.Printf("Warning: Empty content scraped from %s\n", url)
+ } else {
+ logger.Printf("Successfully scraped content from %s (length: %d)\n", url, len(content))
+ }
+
+ results <- struct {
+ url string
+ content string
+ err error
+ }{url, content, nil}
+}
+
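+// scrapeSite crawls a site breadth-first from its BaseURL, following same-site
+// links while the number of visited pages is below MaxDepth, and sends each
+// page's result on the results channel.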
+func scrapeSite(site SiteConfig, config Config, results chan<- struct {
+ url string
+ content string
+ err error
+}, limiter *rate.Limiter) {
+ visited := make(map[string]bool)
+ queue := []string{site.BaseURL}
+
+ for len(queue) > 0 {
+ url := queue[0]
+ queue = queue[1:]
+
+ if visited[url] {
+ continue
+ }
+ visited[url] = true
+
+ if !isAllowedURL(url, site) {
+ continue
+ }
+
+ // Wait for rate limiter before making the request
+ err := limiter.Wait(context.Background())
+ if err != nil {
+ results <- struct {
+ url string
+ content string
+ err error
+ }{url, "", fmt.Errorf("rate limiter error: %v", err)}
+ continue
+ }
+
+ cssLocator, excludeSelectors := getOverrides(url, site)
+ content, err := scrapeURL(url, cssLocator, excludeSelectors)
+ results <- struct {
+ url string
+ content string
+ err error
+ }{url, content, err}
+
+ if len(visited) < site.MaxDepth {
+ links, _ := ExtractLinks(url)
+ for _, link := range links {
+ if !visited[link] && isAllowedURL(link, site) {
+ queue = append(queue, link)
+ }
+ }
+ }
+ }
+}
+
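+// isAllowedURL reports whether urlStr is on the site's host, falls under one of
+// its allowed paths, and does not fall under any excluded path.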
+func isAllowedURL(urlStr string, site SiteConfig) bool {
+ parsedURL, err := url.Parse(urlStr)
+ if err != nil {
+ return false
+ }
+
+ baseURL, _ := url.Parse(site.BaseURL)
+ if parsedURL.Host != baseURL.Host {
+ return false
+ }
+
+ path := parsedURL.Path
+ for _, allowedPath := range site.AllowedPaths {
+ if strings.HasPrefix(path, allowedPath) {
+ for _, excludePath := range site.ExcludePaths {
+ if strings.HasPrefix(path, excludePath) {
+ return false
+ }
+ }
+ return true
+ }
+ }
+
+ return false
+}
+
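+// getOverrides returns the CSS locator and exclude selectors to use for urlStr:
+// the first matching path override wins, with the site-wide CSS locator used when
+// the override does not specify one.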
+func getOverrides(urlStr string, site SiteConfig) (string, []string) {
+ parsedURL, _ := url.Parse(urlStr)
+ path := parsedURL.Path
+
+ for _, override := range site.PathOverrides {
+ if strings.HasPrefix(path, override.Path) {
+ if override.CSSLocator != "" {
+ return override.CSSLocator, override.ExcludeSelectors
+ }
+ return site.CSSLocator, override.ExcludeSelectors
+ }
+ }
+
+ return site.CSSLocator, site.ExcludeSelectors
+}
+
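+// scrapeURL fetches a single page, optionally narrows it to the given CSS
+// locator (minus the excluded selectors), and runs it through ProcessHTMLContent.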
+func scrapeURL(url, cssLocator string, excludeSelectors []string) (string, error) {
+ content, err := FetchWebpageContent(url)
+ if err != nil {
+ return "", err
+ }
+
+ if cssLocator != "" {
+ content, err = ExtractContentWithCSS(content, cssLocator, excludeSelectors)
+ if err != nil {
+ return "", err
+ }
+ }
+
+ return ProcessHTMLContent(content, Config{})
}
func getFilenameFromContent(content, url string) string {
@@ -106,7 +294,7 @@ type URLConfig struct {
// SetupLogger initializes the logger based on the verbose flag
func SetupLogger(verbose bool) {
if verbose {
- logger = log.New(log.Writer(), "SCRAPER: ", log.LstdFlags)
+ logger = log.New(os.Stdout, "SCRAPER: ", log.LstdFlags)
} else {
logger = log.New(ioutil.Discard, "", 0)
}
diff --git a/internal/scraper/scraper_test.go b/internal/scraper/scraper_test.go
new file mode 100644
index 0000000..df36dec
--- /dev/null
+++ b/internal/scraper/scraper_test.go
@@ -0,0 +1,169 @@
+package scraper
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func TestIsAllowedURL(t *testing.T) {
+ site := SiteConfig{
+ BaseURL: "https://example.com",
+ AllowedPaths: []string{"/blog", "/products"},
+ ExcludePaths: []string{"/admin", "/private"},
+ }
+
+ tests := []struct {
+ url string
+ expected bool
+ }{
+ {"https://example.com/blog/post1", true},
+ {"https://example.com/products/item1", true},
+ {"https://example.com/admin/dashboard", false},
+ {"https://example.com/private/data", false},
+ {"https://example.com/other/page", false},
+ {"https://othersite.com/blog/post1", false},
+ }
+
+ for _, test := range tests {
+ result := isAllowedURL(test.url, site)
+ if result != test.expected {
+ t.Errorf("isAllowedURL(%q) = %v, want %v", test.url, result, test.expected)
+ }
+ }
+}
+
+func TestGetOverrides(t *testing.T) {
+ site := SiteConfig{
+ CSSLocator: "main",
+ ExcludeSelectors: []string{".ads"},
+ PathOverrides: []PathOverride{
+ {
+ Path: "/special",
+ CSSLocator: ".special-content",
+ ExcludeSelectors: []string{".sidebar"},
+ },
+ },
+ }
+
+ tests := []struct {
+ url string
+ expectedLocator string
+ expectedExcludes []string
+ }{
+ {"https://example.com/normal", "main", []string{".ads"}},
+ {"https://example.com/special", ".special-content", []string{".sidebar"}},
+ {"https://example.com/special/page", ".special-content", []string{".sidebar"}},
+ }
+
+ for _, test := range tests {
+ locator, excludes := getOverrides(test.url, site)
+ if locator != test.expectedLocator {
+ t.Errorf("getOverrides(%q) locator = %q, want %q", test.url, locator, test.expectedLocator)
+ }
+ if !reflect.DeepEqual(excludes, test.expectedExcludes) {
+ t.Errorf("getOverrides(%q) excludes = %v, want %v", test.url, excludes, test.expectedExcludes)
+ }
+ }
+}
+
+func TestExtractContentWithCSS(t *testing.T) {
+ html := `
+<html>
+<body>
+<main>
+<h1>Main Content</h1>
+<p>This is the main content.</p>
+<div class="ads">Advertisement</div>
+</main>
+<aside>Sidebar content</aside>
+</body>
+</html>
+`
+
+ tests := []struct {
+ includeSelector string
+ excludeSelectors []string
+ expected string
+ }{
+ {"main", nil, "Main Content
\nThis is the main content.
\nAdvertisement
"},
+ {"main", []string{".ads"}, "Main Content
\nThis is the main content.
"},
+ {"aside", nil, "Sidebar content"},
+ }
+
+ for _, test := range tests {
+ result, err := ExtractContentWithCSS(html, test.includeSelector, test.excludeSelectors)
+ if err != nil {
+ t.Errorf("ExtractContentWithCSS() returned error: %v", err)
+ continue
+ }
+ if strings.TrimSpace(result) != strings.TrimSpace(test.expected) {
+ t.Errorf("ExtractContentWithCSS() = %q, want %q", result, test.expected)
+ }
+ }
+}
+
+func TestProcessHTMLContent(t *testing.T) {
+ html := `
+<html>
+<body>
+<h1>Test Heading</h1>
+<p>This is a <strong>test</strong> paragraph.</p>
+<ul>
+<li>Item 1</li>
+<li>Item 2</li>
+</ul>
+</body>
+</html>
+`
+
+ expected := strings.TrimSpace(`
+# Test Heading
+
+This is a **test** paragraph.
+
+- Item 1
+- Item 2
+ `)
+
+ result, err := ProcessHTMLContent(html, Config{})
+ if err != nil {
+ t.Fatalf("ProcessHTMLContent() returned error: %v", err)
+ }
+
+ if strings.TrimSpace(result) != expected {
+ t.Errorf("ProcessHTMLContent() = %q, want %q", result, expected)
+ }
+}
+
+func TestExtractLinks(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/html")
+ w.Write([]byte(`
+<html>
+<body>
+<a href="https://example.com/page1">Page 1</a>
+<a href="https://example.com/page2">Page 2</a>
+<a href="https://othersite.com">Other Site</a>
+</body>
+</html>
+`))
+ }))
+ defer server.Close()
+
+ links, err := ExtractLinks(server.URL)
+ if err != nil {
+ t.Fatalf("ExtractLinks() returned error: %v", err)
+ }
+
+ expectedLinks := []string{
+ "https://example.com/page1",
+ "https://example.com/page2",
+ "https://othersite.com",
+ }
+
+ if !reflect.DeepEqual(links, expectedLinks) {
+ t.Errorf("ExtractLinks() = %v, want %v", links, expectedLinks)
+ }
+}