diff --git a/cmd/files_test.go b/cmd/files_test.go new file mode 100644 index 0000000..3cf875e --- /dev/null +++ b/cmd/files_test.go @@ -0,0 +1,147 @@ +package cmd + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/tnypxl/rollup/internal/config" +) + +func TestMatchGlob(t *testing.T) { + tests := []struct { + pattern string + path string + expected bool + }{ + {"*.go", "file.go", true}, + {"*.go", "file.txt", false}, + {"**/*.go", "dir/file.go", true}, + {"**/*.go", "dir/subdir/file.go", true}, + {"dir/*.go", "dir/file.go", true}, + {"dir/*.go", "otherdir/file.go", false}, + } + + for _, test := range tests { + result := matchGlob(test.pattern, test.path) + if result != test.expected { + t.Errorf("matchGlob(%q, %q) = %v; want %v", test.pattern, test.path, result, test.expected) + } + } +} + +func TestIsCodeGenerated(t *testing.T) { + patterns := []string{"generated_*.go", "**/auto_*.go"} + tests := []struct { + path string + expected bool + }{ + {"generated_file.go", true}, + {"normal_file.go", false}, + {"subdir/auto_file.go", true}, + {"subdir/normal_file.go", false}, + } + + for _, test := range tests { + result := isCodeGenerated(test.path, patterns) + if result != test.expected { + t.Errorf("isCodeGenerated(%q, %v) = %v; want %v", test.path, patterns, result, test.expected) + } + } +} + +func TestIsIgnored(t *testing.T) { + patterns := []string{"*.tmp", "**/*.log"} + tests := []struct { + path string + expected bool + }{ + {"file.tmp", true}, + {"file.go", false}, + {"subdir/file.log", true}, + {"subdir/file.txt", false}, + } + + for _, test := range tests { + result := isIgnored(test.path, patterns) + if result != test.expected { + t.Errorf("isIgnored(%q, %v) = %v; want %v", test.path, patterns, result, test.expected) + } + } +} + +func TestRunRollup(t *testing.T) { + // Create a temporary directory for testing + tempDir, err := os.MkdirTemp("", "rollup_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create some test files + files := map[string]string{ + "file1.go": "package main\n\nfunc main() {}\n", + "file2.txt": "This is a text file.\n", + "subdir/file3.go": "package subdir\n\nfunc Func() {}\n", + "subdir/file4.json": "{\"key\": \"value\"}\n", + } + + for name, content := range files { + path := filepath.Join(tempDir, name) + err := os.MkdirAll(filepath.Dir(path), 0755) + if err != nil { + t.Fatalf("Failed to create directory: %v", err) + } + err = os.WriteFile(path, []byte(content), 0644) + if err != nil { + t.Fatalf("Failed to write file: %v", err) + } + } + + // Set up test configuration + cfg = &config.Config{ + FileTypes: []string{"go", "txt"}, + Ignore: []string{"*.json"}, + CodeGenerated: []string{}, + } + path = tempDir + + // Run the rollup + err = runRollup() + if err != nil { + t.Fatalf("runRollup() failed: %v", err) + } + + // Check if the output file was created + outputFiles, err := filepath.Glob(filepath.Join(tempDir, "*.rollup.md")) + if err != nil { + t.Fatalf("Failed to glob output files: %v", err) + } + if len(outputFiles) != 1 { + t.Fatalf("Expected 1 output file, got %d", len(outputFiles)) + } + + // Read the content of the output file + content, err := os.ReadFile(outputFiles[0]) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check if the content includes the expected files + expectedContent := []string{ + "# File: file1.go", + "# File: file2.txt", + "# File: subdir/file3.go", + } + for _, expected := range expectedContent { + if 
!strings.Contains(string(content), expected) { + t.Errorf("Output file does not contain expected content: %s", expected) + } + } + + // Check if the ignored file is not included + if strings.Contains(string(content), "file4.json") { + t.Errorf("Output file contains ignored file: file4.json") + } +} diff --git a/cmd/web.go b/cmd/web.go index 26ac38c..f35418c 100644 --- a/cmd/web.go +++ b/cmd/web.go @@ -2,6 +2,8 @@ package cmd import ( "fmt" + "io/ioutil" + "log" "net/url" "os" "regexp" @@ -9,6 +11,7 @@ import ( "time" "github.com/spf13/cobra" + "github.com/tnypxl/rollup/internal/config" "github.com/tnypxl/rollup/internal/scraper" ) @@ -38,47 +41,93 @@ func init() { } func runWeb(cmd *cobra.Command, args []string) error { - scraperConfig.Verbose = verbose + scraper.SetupLogger(verbose) + logger := log.New(os.Stdout, "WEB: ", log.LstdFlags) + if !verbose { + logger.SetOutput(ioutil.Discard) + } + logger.Printf("Starting web scraping process with verbose mode: %v", verbose) + scraperConfig.Verbose = verbose - // Use config if available, otherwise use command-line flags - var urlConfigs []scraper.URLConfig - if len(urls) == 0 && len(cfg.Scrape.URLs) > 0 { - urlConfigs = make([]scraper.URLConfig, len(cfg.Scrape.URLs)) - for i, u := range cfg.Scrape.URLs { - urlConfigs[i] = scraper.URLConfig{ - URL: u.URL, - CSSLocator: u.CSSLocator, - ExcludeSelectors: u.ExcludeSelectors, - OutputAlias: u.OutputAlias, - } - } - } else { - urlConfigs = make([]scraper.URLConfig, len(urls)) - for i, u := range urls { - urlConfigs[i] = scraper.URLConfig{URL: u, CSSLocator: includeSelector} - } - } + var siteConfigs []scraper.SiteConfig + if len(cfg.Scrape.Sites) > 0 { + logger.Printf("Using configuration from rollup.yml for %d sites", len(cfg.Scrape.Sites)) + siteConfigs = make([]scraper.SiteConfig, len(cfg.Scrape.Sites)) + for i, site := range cfg.Scrape.Sites { + siteConfigs[i] = scraper.SiteConfig{ + BaseURL: site.BaseURL, + CSSLocator: site.CSSLocator, + ExcludeSelectors: site.ExcludeSelectors, + MaxDepth: site.MaxDepth, + AllowedPaths: site.AllowedPaths, + ExcludePaths: site.ExcludePaths, + OutputAlias: site.OutputAlias, + PathOverrides: convertPathOverrides(site.PathOverrides), + } + logger.Printf("Site %d configuration: BaseURL=%s, CSSLocator=%s, MaxDepth=%d, AllowedPaths=%v", + i+1, site.BaseURL, site.CSSLocator, site.MaxDepth, site.AllowedPaths) + } + } else { + logger.Printf("No sites defined in rollup.yml, falling back to URL-based configuration") + siteConfigs = make([]scraper.SiteConfig, len(urls)) + for i, u := range urls { + siteConfigs[i] = scraper.SiteConfig{ + BaseURL: u, + CSSLocator: includeSelector, + ExcludeSelectors: excludeSelectors, + MaxDepth: depth, + } + logger.Printf("URL %d configuration: BaseURL=%s, CSSLocator=%s, MaxDepth=%d", + i+1, u, includeSelector, depth) + } + } - if len(urlConfigs) == 0 { - return fmt.Errorf("no URLs provided. Use --urls flag with comma-separated URLs or set 'scrape.urls' in the rollup.yml file") - } + if len(siteConfigs) == 0 { + logger.Println("Error: No sites or URLs provided") + return fmt.Errorf("no sites or URLs provided. 
Use --urls flag with comma-separated URLs or set 'scrape.sites' in the rollup.yml file") + } - scraperConfig := scraper.Config{ - URLs: urlConfigs, - OutputType: outputType, - Verbose: verbose, - } + // Set default values for rate limiting + defaultRequestsPerSecond := 1.0 + defaultBurstLimit := 3 - scrapedContent, err := scraper.ScrapeMultipleURLs(scraperConfig) - if err != nil { - return fmt.Errorf("error scraping content: %v", err) - } + // Use default values if not set in the configuration + requestsPerSecond := cfg.Scrape.RequestsPerSecond + if requestsPerSecond == 0 { + requestsPerSecond = defaultRequestsPerSecond + } + burstLimit := cfg.Scrape.BurstLimit + if burstLimit == 0 { + burstLimit = defaultBurstLimit + } - if outputType == "single" { - return writeSingleFile(scrapedContent) - } else { - return writeMultipleFiles(scrapedContent) - } + scraperConfig := scraper.Config{ + Sites: siteConfigs, + OutputType: outputType, + Verbose: verbose, + Scrape: scraper.ScrapeConfig{ + RequestsPerSecond: requestsPerSecond, + BurstLimit: burstLimit, + }, + } + logger.Printf("Scraper configuration: OutputType=%s, RequestsPerSecond=%f, BurstLimit=%d", + outputType, requestsPerSecond, burstLimit) + + logger.Println("Starting scraping process") + scrapedContent, err := scraper.ScrapeSites(scraperConfig) + if err != nil { + logger.Printf("Error occurred during scraping: %v", err) + return fmt.Errorf("error scraping content: %v", err) + } + logger.Printf("Scraping completed. Total content scraped: %d", len(scrapedContent)) + + if outputType == "single" { + logger.Println("Writing content to a single file") + return writeSingleFile(scrapedContent) + } else { + logger.Println("Writing content to multiple files") + return writeMultipleFiles(scrapedContent) + } } func writeSingleFile(content map[string]string) error { @@ -102,20 +151,26 @@ func writeSingleFile(content map[string]string) error { func writeMultipleFiles(content map[string]string) error { for url, c := range content { - filename := getFilenameFromContent(c, url) + filename, err := getFilenameFromContent(c, url) + if err != nil { + return fmt.Errorf("error generating filename for %s: %v", url, err) + } + file, err := os.Create(filename) if err != nil { return fmt.Errorf("error creating output file %s: %v", filename, err) } - _, err = fmt.Fprintf(file, "# Content from %s\n\n%s", url, c) - file.Close() + _, err = file.WriteString(fmt.Sprintf("# Content from %s\n\n%s\n", url, c)) if err != nil { + file.Close() return fmt.Errorf("error writing content to file %s: %v", filename, err) } + file.Close() fmt.Printf("Content from %s has been saved to %s\n", url, filename) } + return nil } @@ -136,13 +191,13 @@ func scrapeURL(urlStr string, depth int, visited map[string]bool) (string, error visited[urlStr] = true - content, err := extractAndConvertContent(urlStr) + content, err := testExtractAndConvertContent(urlStr) if err != nil { return "", err } if depth > 0 { - links, err := scraper.ExtractLinks(urlStr) + links, err := testExtractLinks(urlStr) if err != nil { return content, fmt.Errorf("error extracting links: %v", err) } @@ -160,6 +215,9 @@ func scrapeURL(urlStr string, depth int, visited map[string]bool) (string, error return content, nil } +var testExtractAndConvertContent = extractAndConvertContent +var testExtractLinks = scraper.ExtractLinks + func extractAndConvertContent(urlStr string) (string, error) { content, err := scraper.FetchWebpageContent(urlStr) if err != nil { @@ -187,17 +245,32 @@ func extractAndConvertContent(urlStr string) 
(string, error) {
 	return header + markdown + "\n\n", nil
 }
 
-func getFilenameFromContent(content, url string) string {
+func getFilenameFromContent(content, urlStr string) (string, error) {
 	// Try to extract title from content
 	titleStart := strings.Index(content, "<title>")
 	titleEnd := strings.Index(content, "</title>")
 	if titleStart != -1 && titleEnd != -1 && titleEnd > titleStart {
-		title := content[titleStart+7 : titleEnd]
-		return sanitizeFilename(title) + ".md"
+		title := strings.TrimSpace(content[titleStart+7 : titleEnd])
+		if title != "" {
+			return sanitizeFilename(title) + ".rollup.md", nil
+		}
 	}
 
-	// If no title found, use the URL
-	return sanitizeFilename(url) + ".md"
+	// If no title found or title is empty, use the URL
+	parsedURL, err := url.Parse(urlStr)
+	if err != nil {
+		return "", fmt.Errorf("invalid URL: %v", err)
+	}
+
+	if parsedURL.Host == "" {
+		return "", fmt.Errorf("invalid URL: missing host")
+	}
+
+	filename := parsedURL.Host
+	if parsedURL.Path != "" && parsedURL.Path != "/" {
+		filename += strings.TrimSuffix(parsedURL.Path, "/")
+	}
+	return sanitizeFilename(filename) + ".rollup.md", nil
 }
 
 func sanitizeFilename(name string) string {
@@ -215,3 +288,15 @@ func sanitizeFilename(name string) string {
 
 	return name
 }
+
+func convertPathOverrides(configOverrides []config.PathOverride) []scraper.PathOverride {
+	scraperOverrides := make([]scraper.PathOverride, len(configOverrides))
+	for i, override := range configOverrides {
+		scraperOverrides[i] = scraper.PathOverride{
+			Path:             override.Path,
+			CSSLocator:       override.CSSLocator,
+			ExcludeSelectors: override.ExcludeSelectors,
+		}
+	}
+	return scraperOverrides
+}
diff --git a/cmd/web_test.go b/cmd/web_test.go
new file mode 100644
index 0000000..8e470be
--- /dev/null
+++ b/cmd/web_test.go
@@ -0,0 +1,154 @@
+package cmd
+
+import (
+	"testing"
+	"strings"
+	"github.com/tnypxl/rollup/internal/config"
+)
+
+func TestConvertPathOverrides(t *testing.T) {
+	configOverrides := []config.PathOverride{
+		{
+			Path:             "/blog",
+			CSSLocator:       "article",
+			ExcludeSelectors: []string{".ads", ".comments"},
+		},
+		{
+			Path:             "/products",
+			CSSLocator:       ".product-description",
+			ExcludeSelectors: []string{".related-items"},
+		},
+	}
+
+	scraperOverrides := convertPathOverrides(configOverrides)
+
+	if len(scraperOverrides) != len(configOverrides) {
+		t.Errorf("Expected %d overrides, got %d", len(configOverrides), len(scraperOverrides))
+	}
+
+	for i, override := range scraperOverrides {
+		if override.Path != configOverrides[i].Path {
+			t.Errorf("Expected Path %s, got %s", configOverrides[i].Path, override.Path)
+		}
+		if override.CSSLocator != configOverrides[i].CSSLocator {
+			t.Errorf("Expected CSSLocator %s, got %s", configOverrides[i].CSSLocator, override.CSSLocator)
+		}
+		if len(override.ExcludeSelectors) != len(configOverrides[i].ExcludeSelectors) {
+			t.Errorf("Expected %d ExcludeSelectors, got %d", len(configOverrides[i].ExcludeSelectors), len(override.ExcludeSelectors))
+		}
+		for j, selector := range override.ExcludeSelectors {
+			if selector != configOverrides[i].ExcludeSelectors[j] {
+				t.Errorf("Expected ExcludeSelector %s, got %s", configOverrides[i].ExcludeSelectors[j], selector)
+			}
+		}
+	}
+}
+
+func TestSanitizeFilename(t *testing.T) {
+	tests := []struct {
+		input    string
+		expected string
+	}{
+		{"Hello, World!", "Hello_World"},
+		{"file/with/path", "file_with_path"},
+		{"file.with.dots", "file_with_dots"},
+		{"___leading_underscores___", "leading_underscores"},
+		{"", "untitled"},
+		{"!@#$%^&*()", "untitled"},
+	}
+
+	for _, test := range tests {
+		result := sanitizeFilename(test.input)
+		if result != test.expected {
+			t.Errorf("sanitizeFilename(%q) = %q; want %q", test.input, result, test.expected)
+		}
+	}
+}
+
+func TestGetFilenameFromContent(t *testing.T) {
+	tests := []struct {
+		content   string
+		url       string
+		expected  string
+		expectErr bool
+	}{
+		{"<title>Test Page</title>", "http://example.com", "Test_Page.rollup.md", false},
+		{"No title here", "http://example.com/page", "example_com_page.rollup.md", false},
+		{"<title> Trim Me </title>", "http://example.com", "Trim_Me.rollup.md", false},
+		{"", "http://example.com", "example_com.rollup.md", false},
+		{" ", "http://example.com", "example_com.rollup.md", false},
+		{"Invalid URL", "not a valid url", "", true},
+		{"No host", "http://", "", true},
+	}
+
+	for _, test := range tests {
+		result, err := getFilenameFromContent(test.content, test.url)
+		if test.expectErr {
+			if err == nil {
+				t.Errorf("getFilenameFromContent(%q, %q) expected an error, but got none", test.content, test.url)
+			}
+		} else {
+			if err != nil {
+				t.Errorf("getFilenameFromContent(%q, %q) unexpected error: %v", test.content, test.url, err)
+			}
+			if result != test.expected {
+				t.Errorf("getFilenameFromContent(%q, %q) = %q; want %q", test.content, test.url, result, test.expected)
+			}
+		}
+	}
+}
+
+// Mock functions for testing
+func mockExtractAndConvertContent(urlStr string) (string, error) {
+	return "Mocked content for " + urlStr, nil
+}
+
+func mockExtractLinks(urlStr string) ([]string, error) {
+	return []string{"http://example.com/link1", "http://example.com/link2"}, nil
+}
+
+func TestScrapeURL(t *testing.T) {
+	// Store the original functions
+	originalExtractAndConvertContent := testExtractAndConvertContent
+	originalExtractLinks := testExtractLinks
+
+	// Define mock functions
+	testExtractAndConvertContent = func(urlStr string) (string, error) {
+		return "Mocked content for " + urlStr, nil
+	}
+	testExtractLinks = func(urlStr string) ([]string, error) {
+		return []string{"http://example.com/link1", "http://example.com/link2"}, nil
+	}
+
+	// Defer the restoration of original functions
+	defer func() {
+		testExtractAndConvertContent = originalExtractAndConvertContent
+		testExtractLinks = originalExtractLinks
+	}()
+
+	tests := []struct {
+		url           string
+		depth         int
+		expectedCalls int
+	}{
+		{"http://example.com", 0, 1},
+		{"http://example.com", 1, 3},
+		{"http://example.com", 2, 3}, // Same as depth 1 because our mock only returns 2 links
+	}
+
+	for _, test := range tests {
+		visited := make(map[string]bool)
+		content, err := scrapeURL(test.url, test.depth, visited)
+		if err != nil {
+			t.Errorf("scrapeURL(%q, %d) returned error: %v", test.url, test.depth, err)
+			continue
+		}
+		if len(visited) != test.expectedCalls {
+			t.Errorf("scrapeURL(%q, %d) made %d calls, expected %d", test.url, test.depth, len(visited), test.expectedCalls)
+		}
+		expectedContent := "Mocked content for " + test.url
+		if !strings.Contains(content, expectedContent) {
+			t.Errorf("scrapeURL(%q, %d) content doesn't contain %q", test.url, test.depth, expectedContent)
+		}
+	}
+}
diff --git a/go.mod b/go.mod
index 47b03f2..6946697 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.23
 require (
 	github.com/JohannesKaufmann/html-to-markdown v1.6.0
 	github.com/spf13/cobra v1.8.1
+	golang.org/x/time v0.6.0
 )
 
 require (
diff --git a/go.sum b/go.sum
index 6903da8..a6b8c57 100644
--- a/go.sum
+++ b/go.sum
@@ -102,6 +102,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 golang.org/x/text v0.9.0/go.mod 
h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= +golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/internal/config/config.go b/internal/config/config.go index 6ffd454..0042396 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,15 +15,27 @@ type Config struct { } type ScrapeConfig struct { - URLs []URLConfig `yaml:"urls"` - OutputType string `yaml:"output_type"` + Sites []SiteConfig `yaml:"sites"` + OutputType string `yaml:"output_type"` + RequestsPerSecond float64 `yaml:"requests_per_second"` + BurstLimit int `yaml:"burst_limit"` } -type URLConfig struct { - URL string `yaml:"url"` - CSSLocator string `yaml:"css_locator"` - ExcludeSelectors []string `yaml:"exclude_selectors"` - OutputAlias string `yaml:"output_alias"` +type SiteConfig struct { + BaseURL string `yaml:"base_url"` + CSSLocator string `yaml:"css_locator"` + ExcludeSelectors []string `yaml:"exclude_selectors"` + MaxDepth int `yaml:"max_depth"` + AllowedPaths []string `yaml:"allowed_paths"` + ExcludePaths []string `yaml:"exclude_paths"` + OutputAlias string `yaml:"output_alias"` + PathOverrides []PathOverride `yaml:"path_overrides"` +} + +type PathOverride struct { + Path string `yaml:"path"` + CSSLocator string `yaml:"css_locator"` + ExcludeSelectors []string `yaml:"exclude_selectors"` } func Load(configPath string) (*Config, error) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..a05c23f --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,120 @@ +package config + +import ( + "os" + "reflect" + "testing" +) + +func TestLoad(t *testing.T) { + // Create a temporary config file + content := []byte(` +file_types: + - go + - md +ignore: + - "*.tmp" + - "**/*.log" +code_generated: + - "generated_*.go" +scrape: + sites: + - base_url: "https://example.com" + css_locator: "main" + exclude_selectors: + - ".ads" + max_depth: 2 + allowed_paths: + - "/blog" + exclude_paths: + - "/admin" + output_alias: "example" + path_overrides: + - path: "/special" + css_locator: ".special-content" + exclude_selectors: + - ".sidebar" + output_type: "single" + requests_per_second: 1.0 + burst_limit: 5 +`) + + tmpfile, err := os.CreateTemp("", "config*.yml") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(content); err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + if err := tmpfile.Close(); err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } + + // Test loading the config + config, err := Load(tmpfile.Name()) + if err != nil { + t.Fatalf("Load() failed: %v", err) + } + + // Check if the loaded config matches the expected values + expectedConfig := &Config{ + FileTypes: []string{"go", "md"}, + Ignore: []string{"*.tmp", "**/*.log"}, + CodeGenerated: []string{"generated_*.go"}, + Scrape: ScrapeConfig{ + Sites: []SiteConfig{ + { + BaseURL: "https://example.com", + 
CSSLocator: "main", + ExcludeSelectors: []string{".ads"}, + MaxDepth: 2, + AllowedPaths: []string{"/blog"}, + ExcludePaths: []string{"/admin"}, + OutputAlias: "example", + PathOverrides: []PathOverride{ + { + Path: "/special", + CSSLocator: ".special-content", + ExcludeSelectors: []string{".sidebar"}, + }, + }, + }, + }, + OutputType: "single", + RequestsPerSecond: 1.0, + BurstLimit: 5, + }, + } + + if !reflect.DeepEqual(config, expectedConfig) { + t.Errorf("Loaded config does not match expected config.\nGot: %+v\nWant: %+v", config, expectedConfig) + } +} + +func TestDefaultConfigPath(t *testing.T) { + expected := "rollup.yml" + result := DefaultConfigPath() + if result != expected { + t.Errorf("DefaultConfigPath() = %q, want %q", result, expected) + } +} + +func TestFileExists(t *testing.T) { + // Test with an existing file + tmpfile, err := os.CreateTemp("", "testfile") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpfile.Name()) + + if !FileExists(tmpfile.Name()) { + t.Errorf("FileExists(%q) = false, want true", tmpfile.Name()) + } + + // Test with a non-existing file + if FileExists("non_existing_file.txt") { + t.Errorf("FileExists(\"non_existing_file.txt\") = true, want false") + } +} diff --git a/internal/scraper/scraper.go b/internal/scraper/scraper.go index f963049..91b1399 100644 --- a/internal/scraper/scraper.go +++ b/internal/scraper/scraper.go @@ -5,13 +5,18 @@ import ( "io/ioutil" "log" "math/rand" + "net/url" + "os" "regexp" "strings" "time" + "sync" + "context" "github.com/PuerkitoBio/goquery" "github.com/playwright-community/playwright-go" md "github.com/JohannesKaufmann/html-to-markdown" + "golang.org/x/time/rate" ) var logger *log.Logger @@ -23,57 +28,240 @@ var ( // Config holds the scraper configuration type Config struct { - URLs []URLConfig + Sites []SiteConfig OutputType string Verbose bool + Scrape ScrapeConfig } -// ScrapeMultipleURLs scrapes multiple URLs concurrently -func ScrapeMultipleURLs(config Config) (map[string]string, error) { - results := make(chan struct { - url string - content string - err error - }, len(config.URLs)) - - for _, urlConfig := range config.URLs { - go func(cfg URLConfig) { - content, err := scrapeURL(cfg) - results <- struct { - url string - content string - err error - }{cfg.URL, content, err} - }(urlConfig) - } - - scrapedContent := make(map[string]string) - for i := 0; i < len(config.URLs); i++ { - result := <-results - if result.err != nil { - logger.Printf("Error scraping %s: %v\n", result.url, result.err) - continue - } - scrapedContent[result.url] = result.content - } - - return scrapedContent, nil +// ScrapeConfig holds the scraping-specific configuration +type ScrapeConfig struct { + RequestsPerSecond float64 + BurstLimit int } -func scrapeURL(config URLConfig) (string, error) { - content, err := FetchWebpageContent(config.URL) - if err != nil { - return "", err - } +// SiteConfig holds configuration for a single site +type SiteConfig struct { + BaseURL string + CSSLocator string + ExcludeSelectors []string + MaxDepth int + AllowedPaths []string + ExcludePaths []string + OutputAlias string + PathOverrides []PathOverride +} - if config.CSSLocator != "" { - content, err = ExtractContentWithCSS(content, config.CSSLocator, config.ExcludeSelectors) - if err != nil { - return "", err - } - } +// PathOverride holds path-specific overrides +type PathOverride struct { + Path string + CSSLocator string + ExcludeSelectors []string +} - return ProcessHTMLContent(content, Config{}) +func 
ScrapeSites(config Config) (map[string]string, error) { + logger.Println("Starting ScrapeSites function - Verbose mode is active") + results := make(chan struct { + url string + content string + err error + }) + + limiter := rate.NewLimiter(rate.Limit(config.Scrape.RequestsPerSecond), config.Scrape.BurstLimit) + logger.Printf("Rate limiter configured with %f requests per second and burst limit of %d\n", config.Scrape.RequestsPerSecond, config.Scrape.BurstLimit) + + var wg sync.WaitGroup + totalURLs := 0 + for _, site := range config.Sites { + logger.Printf("Processing site: %s\n", site.BaseURL) + wg.Add(1) + go func(site SiteConfig) { + defer wg.Done() + for _, path := range site.AllowedPaths { + fullURL := site.BaseURL + path + totalURLs++ + logger.Printf("Queueing URL for scraping: %s\n", fullURL) + scrapeSingleURL(fullURL, site, config, results, limiter) + } + }(site) + } + + go func() { + wg.Wait() + close(results) + logger.Println("All goroutines completed, results channel closed") + }() + + scrapedContent := make(map[string]string) + for result := range results { + if result.err != nil { + logger.Printf("Error scraping %s: %v\n", result.url, result.err) + continue + } + logger.Printf("Successfully scraped content from %s (length: %d)\n", result.url, len(result.content)) + scrapedContent[result.url] = result.content + } + + logger.Printf("Total URLs processed: %d\n", totalURLs) + logger.Printf("Successfully scraped content from %d URLs\n", len(scrapedContent)) + + return scrapedContent, nil +} + +func scrapeSingleURL(url string, site SiteConfig, config Config, results chan<- struct { + url string + content string + err error +}, limiter *rate.Limiter) { + logger.Printf("Starting to scrape URL: %s\n", url) + + // Wait for rate limiter before making the request + err := limiter.Wait(context.Background()) + if err != nil { + logger.Printf("Rate limiter error for %s: %v\n", url, err) + results <- struct { + url string + content string + err error + }{url, "", fmt.Errorf("rate limiter error: %v", err)} + return + } + + cssLocator, excludeSelectors := getOverrides(url, site) + logger.Printf("Using CSS locator for %s: %s\n", url, cssLocator) + logger.Printf("Exclude selectors for %s: %v\n", url, excludeSelectors) + + content, err := scrapeURL(url, cssLocator, excludeSelectors) + if err != nil { + logger.Printf("Error scraping %s: %v\n", url, err) + results <- struct { + url string + content string + err error + }{url, "", err} + return + } + + if content == "" { + logger.Printf("Warning: Empty content scraped from %s\n", url) + } else { + logger.Printf("Successfully scraped content from %s (length: %d)\n", url, len(content)) + } + + results <- struct { + url string + content string + err error + }{url, content, nil} +} + +func scrapeSite(site SiteConfig, config Config, results chan<- struct { + url string + content string + err error +}, limiter *rate.Limiter) { + visited := make(map[string]bool) + queue := []string{site.BaseURL} + + for len(queue) > 0 { + url := queue[0] + queue = queue[1:] + + if visited[url] { + continue + } + visited[url] = true + + if !isAllowedURL(url, site) { + continue + } + + // Wait for rate limiter before making the request + err := limiter.Wait(context.Background()) + if err != nil { + results <- struct { + url string + content string + err error + }{url, "", fmt.Errorf("rate limiter error: %v", err)} + continue + } + + cssLocator, excludeSelectors := getOverrides(url, site) + content, err := scrapeURL(url, cssLocator, excludeSelectors) + results <- struct { + 
url string + content string + err error + }{url, content, err} + + if len(visited) < site.MaxDepth { + links, _ := ExtractLinks(url) + for _, link := range links { + if !visited[link] && isAllowedURL(link, site) { + queue = append(queue, link) + } + } + } + } +} + +func isAllowedURL(urlStr string, site SiteConfig) bool { + parsedURL, err := url.Parse(urlStr) + if err != nil { + return false + } + + baseURL, _ := url.Parse(site.BaseURL) + if parsedURL.Host != baseURL.Host { + return false + } + + path := parsedURL.Path + for _, allowedPath := range site.AllowedPaths { + if strings.HasPrefix(path, allowedPath) { + for _, excludePath := range site.ExcludePaths { + if strings.HasPrefix(path, excludePath) { + return false + } + } + return true + } + } + + return false +} + +func getOverrides(urlStr string, site SiteConfig) (string, []string) { + parsedURL, _ := url.Parse(urlStr) + path := parsedURL.Path + + for _, override := range site.PathOverrides { + if strings.HasPrefix(path, override.Path) { + if override.CSSLocator != "" { + return override.CSSLocator, override.ExcludeSelectors + } + return site.CSSLocator, override.ExcludeSelectors + } + } + + return site.CSSLocator, site.ExcludeSelectors +} + +func scrapeURL(url, cssLocator string, excludeSelectors []string) (string, error) { + content, err := FetchWebpageContent(url) + if err != nil { + return "", err + } + + if cssLocator != "" { + content, err = ExtractContentWithCSS(content, cssLocator, excludeSelectors) + if err != nil { + return "", err + } + } + + return ProcessHTMLContent(content, Config{}) } func getFilenameFromContent(content, url string) string { @@ -106,7 +294,7 @@ type URLConfig struct { // SetupLogger initializes the logger based on the verbose flag func SetupLogger(verbose bool) { if verbose { - logger = log.New(log.Writer(), "SCRAPER: ", log.LstdFlags) + logger = log.New(os.Stdout, "SCRAPER: ", log.LstdFlags) } else { logger = log.New(ioutil.Discard, "", 0) } diff --git a/internal/scraper/scraper_test.go b/internal/scraper/scraper_test.go new file mode 100644 index 0000000..df36dec --- /dev/null +++ b/internal/scraper/scraper_test.go @@ -0,0 +1,169 @@ +package scraper + +import ( + "testing" + "net/http" + "net/http/httptest" + "strings" +) + +func TestIsAllowedURL(t *testing.T) { + site := SiteConfig{ + BaseURL: "https://example.com", + AllowedPaths: []string{"/blog", "/products"}, + ExcludePaths: []string{"/admin", "/private"}, + } + + tests := []struct { + url string + expected bool + }{ + {"https://example.com/blog/post1", true}, + {"https://example.com/products/item1", true}, + {"https://example.com/admin/dashboard", false}, + {"https://example.com/private/data", false}, + {"https://example.com/other/page", false}, + {"https://othersite.com/blog/post1", false}, + } + + for _, test := range tests { + result := isAllowedURL(test.url, site) + if result != test.expected { + t.Errorf("isAllowedURL(%q) = %v, want %v", test.url, result, test.expected) + } + } +} + +func TestGetOverrides(t *testing.T) { + site := SiteConfig{ + CSSLocator: "main", + ExcludeSelectors: []string{".ads"}, + PathOverrides: []PathOverride{ + { + Path: "/special", + CSSLocator: ".special-content", + ExcludeSelectors: []string{".sidebar"}, + }, + }, + } + + tests := []struct { + url string + expectedLocator string + expectedExcludes []string + }{ + {"https://example.com/normal", "main", []string{".ads"}}, + {"https://example.com/special", ".special-content", []string{".sidebar"}}, + {"https://example.com/special/page", ".special-content", 
[]string{".sidebar"}}, + } + + for _, test := range tests { + locator, excludes := getOverrides(test.url, site) + if locator != test.expectedLocator { + t.Errorf("getOverrides(%q) locator = %q, want %q", test.url, locator, test.expectedLocator) + } + if !reflect.DeepEqual(excludes, test.expectedExcludes) { + t.Errorf("getOverrides(%q) excludes = %v, want %v", test.url, excludes, test.expectedExcludes) + } + } +} + +func TestExtractContentWithCSS(t *testing.T) { + html := ` + + +
+

Main Content

+

This is the main content.

+
Advertisement
+
+ + + + ` + + tests := []struct { + includeSelector string + excludeSelectors []string + expected string + }{ + {"main", nil, "

Main Content

\n

This is the main content.

\n
Advertisement
"}, + {"main", []string{".ads"}, "

Main Content

\n

This is the main content.

"}, + {"aside", nil, "Sidebar content"}, + } + + for _, test := range tests { + result, err := ExtractContentWithCSS(html, test.includeSelector, test.excludeSelectors) + if err != nil { + t.Errorf("ExtractContentWithCSS() returned error: %v", err) + continue + } + if strings.TrimSpace(result) != strings.TrimSpace(test.expected) { + t.Errorf("ExtractContentWithCSS() = %q, want %q", result, test.expected) + } + } +} + +func TestProcessHTMLContent(t *testing.T) { + html := ` + + +

Test Heading

+

This is a test paragraph.

+ + + + ` + + expected := strings.TrimSpace(` +# Test Heading + +This is a **test** paragraph. + +- Item 1 +- Item 2 + `) + + result, err := ProcessHTMLContent(html, Config{}) + if err != nil { + t.Fatalf("ProcessHTMLContent() returned error: %v", err) + } + + if strings.TrimSpace(result) != expected { + t.Errorf("ProcessHTMLContent() = %q, want %q", result, expected) + } +} + +func TestExtractLinks(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(` + + + Page 1 + Page 2 + Other Site + + + `)) + })) + defer server.Close() + + links, err := ExtractLinks(server.URL) + if err != nil { + t.Fatalf("ExtractLinks() returned error: %v", err) + } + + expectedLinks := []string{ + "https://example.com/page1", + "https://example.com/page2", + "https://othersite.com", + } + + if !reflect.DeepEqual(links, expectedLinks) { + t.Errorf("ExtractLinks() = %v, want %v", links, expectedLinks) + } +}
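For reviewers: a minimal sketch of how the new `scraper.ScrapeSites` entry point introduced above can be driven outside of `runWeb`. It uses only identifiers that appear in this diff (`scraper.SetupLogger`, `scraper.Config`, `scraper.SiteConfig`, `scraper.ScrapeConfig`, `scraper.ScrapeSites`); the base URL and allowed path are illustrative values, the rate-limit numbers mirror the defaults `runWeb` falls back to, and because the package lives under `internal/` this only compiles from inside the rollup module.

```go
package main

import (
	"fmt"
	"log"

	"github.com/tnypxl/rollup/internal/scraper"
)

func main() {
	// ScrapeSites logs progress through the package logger, so initialize it first.
	scraper.SetupLogger(true)

	cfg := scraper.Config{
		Sites: []scraper.SiteConfig{
			{
				BaseURL:      "https://example.com", // illustrative site, not a project default
				CSSLocator:   "main",
				AllowedPaths: []string{"/blog"}, // ScrapeSites only visits BaseURL + AllowedPaths entries
				MaxDepth:     2,
			},
		},
		OutputType: "single",
		Verbose:    true,
		Scrape: scraper.ScrapeConfig{
			RequestsPerSecond: 1.0, // same fallback values runWeb applies when the config omits them
			BurstLimit:        3,
		},
	}

	// Returns a map of URL -> converted markdown; per-URL failures are logged and skipped.
	content, err := scraper.ScrapeSites(cfg)
	if err != nil {
		log.Fatalf("scrape failed: %v", err)
	}
	for url, md := range content {
		fmt.Printf("scraped %s (%d bytes of markdown)\n", url, len(md))
	}
}
```

The equivalent YAML keys (`scrape.sites`, `requests_per_second`, `burst_limit`) are exercised by the fixture in the new `internal/config/config_test.go`.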