[chore] migrate oauth2 -> codeberg (#3857)

This commit is contained in:
tobi
2025-03-02 16:42:51 +01:00
committed by GitHub
parent 49c12636c6
commit 8488ac9286
65 changed files with 1677 additions and 1221 deletions

View File

@ -14,8 +14,34 @@ This package provides various compression algorithms.
[![Go](https://github.com/klauspost/compress/actions/workflows/go.yml/badge.svg)](https://github.com/klauspost/compress/actions/workflows/go.yml)
[![Sourcegraph Badge](https://sourcegraph.com/github.com/klauspost/compress/-/badge.svg)](https://sourcegraph.com/github.com/klauspost/compress?badge)
# package usage
Use `go get github.com/klauspost/compress@latest` to add it to your project.
This package will support the current Go version and 2 versions back.
* Use the `nounsafe` tag to disable all use of the "unsafe" package.
* Use the `noasm` tag to disable all assembly across packages.
Use the links above for more information on each.
# changelog
* Feb 19th, 2025 - [1.18.0](https://github.com/klauspost/compress/releases/tag/v1.18.0)
* Add unsafe little endian loaders https://github.com/klauspost/compress/pull/1036
* fix: check `r.err != nil` but return a nil value error `err` by @alingse in https://github.com/klauspost/compress/pull/1028
* flate: Simplify L4-6 loading https://github.com/klauspost/compress/pull/1043
* flate: Simplify matchlen (remove asm) https://github.com/klauspost/compress/pull/1045
* s2: Improve small block compression speed w/o asm https://github.com/klauspost/compress/pull/1048
* flate: Fix matchlen L5+L6 https://github.com/klauspost/compress/pull/1049
* flate: Cleanup & reduce casts https://github.com/klauspost/compress/pull/1050
* Oct 11th, 2024 - [1.17.11](https://github.com/klauspost/compress/releases/tag/v1.17.11)
* zstd: Fix extra CRC written with multiple Close calls https://github.com/klauspost/compress/pull/1017
* s2: Don't use stack for index tables https://github.com/klauspost/compress/pull/1014
* gzhttp: No content-type on no body response code by @juliens in https://github.com/klauspost/compress/pull/1011
* gzhttp: Do not set the content-type when response has no body by @kevinpollet in https://github.com/klauspost/compress/pull/1013
* Sep 23rd, 2024 - [1.17.10](https://github.com/klauspost/compress/releases/tag/v1.17.10)
* gzhttp: Add TransportAlwaysDecompress option. https://github.com/klauspost/compress/pull/978
* gzhttp: Add supported decompress request body by @mirecl in https://github.com/klauspost/compress/pull/1002
@ -65,9 +91,9 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Fix rare *CORRUPTION* output in "best" mode. See https://github.com/klauspost/compress/pull/876
* Oct 14th, 2023 - [v1.17.1](https://github.com/klauspost/compress/releases/tag/v1.17.1)
* s2: Fix S2 "best" dictionary wrong encoding by @klauspost in https://github.com/klauspost/compress/pull/871
* s2: Fix S2 "best" dictionary wrong encoding https://github.com/klauspost/compress/pull/871
* flate: Reduce allocations in decompressor and minor code improvements by @fakefloordiv in https://github.com/klauspost/compress/pull/869
* s2: Fix EstimateBlockSize on 6&7 length input by @klauspost in https://github.com/klauspost/compress/pull/867
* s2: Fix EstimateBlockSize on 6&7 length input https://github.com/klauspost/compress/pull/867
* Sept 19th, 2023 - [v1.17.0](https://github.com/klauspost/compress/releases/tag/v1.17.0)
* Add experimental dictionary builder https://github.com/klauspost/compress/pull/853
@ -124,7 +150,7 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
<summary>See changes to v1.15.x</summary>
* Jan 21st, 2023 (v1.15.15)
* deflate: Improve level 7-9 by @klauspost in https://github.com/klauspost/compress/pull/739
* deflate: Improve level 7-9 https://github.com/klauspost/compress/pull/739
* zstd: Add delta encoding support by @greatroar in https://github.com/klauspost/compress/pull/728
* zstd: Various speed improvements by @greatroar https://github.com/klauspost/compress/pull/741 https://github.com/klauspost/compress/pull/734 https://github.com/klauspost/compress/pull/736 https://github.com/klauspost/compress/pull/744 https://github.com/klauspost/compress/pull/743 https://github.com/klauspost/compress/pull/745
* gzhttp: Add SuffixETag() and DropETag() options to prevent ETag collisions on compressed responses by @willbicks in https://github.com/klauspost/compress/pull/740
@ -167,7 +193,7 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Fix decoder crash on amd64 (no BMI) on invalid input https://github.com/klauspost/compress/pull/645
* zstd: Disable decoder extended memory copies (amd64) due to possible crashes https://github.com/klauspost/compress/pull/644
* zstd: Allow single segments up to "max decoded size" by @klauspost in https://github.com/klauspost/compress/pull/643
* zstd: Allow single segments up to "max decoded size" https://github.com/klauspost/compress/pull/643
* July 13, 2022 (v1.15.8)
@ -209,7 +235,7 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Speed up when WithDecoderLowmem(false) https://github.com/klauspost/compress/pull/599
* zstd: faster next state update in BMI2 version of decode by @WojciechMula in https://github.com/klauspost/compress/pull/593
* huff0: Do not check max size when reading table. https://github.com/klauspost/compress/pull/586
* flate: Inplace hashing for level 7-9 by @klauspost in https://github.com/klauspost/compress/pull/590
* flate: Inplace hashing for level 7-9 https://github.com/klauspost/compress/pull/590
* May 11, 2022 (v1.15.4)
@ -236,12 +262,12 @@ https://github.com/klauspost/compress/pull/919 https://github.com/klauspost/comp
* zstd: Add stricter block size checks in [#523](https://github.com/klauspost/compress/pull/523)
* Mar 3, 2022 (v1.15.0)
* zstd: Refactor decoder by @klauspost in [#498](https://github.com/klauspost/compress/pull/498)
* zstd: Add stream encoding without goroutines by @klauspost in [#505](https://github.com/klauspost/compress/pull/505)
* zstd: Refactor decoder [#498](https://github.com/klauspost/compress/pull/498)
* zstd: Add stream encoding without goroutines [#505](https://github.com/klauspost/compress/pull/505)
* huff0: Prevent single blocks exceeding 16 bits by @klauspost in[#507](https://github.com/klauspost/compress/pull/507)
* flate: Inline literal emission by @klauspost in [#509](https://github.com/klauspost/compress/pull/509)
* gzhttp: Add zstd to transport by @klauspost in [#400](https://github.com/klauspost/compress/pull/400)
* gzhttp: Make content-type optional by @klauspost in [#510](https://github.com/klauspost/compress/pull/510)
* flate: Inline literal emission [#509](https://github.com/klauspost/compress/pull/509)
* gzhttp: Add zstd to transport [#400](https://github.com/klauspost/compress/pull/400)
* gzhttp: Make content-type optional [#510](https://github.com/klauspost/compress/pull/510)
Both compression and decompression now supports "synchronous" stream operations. This means that whenever "concurrency" is set to 1, they will operate without spawning goroutines.
@ -258,7 +284,7 @@ While the release has been extensively tested, it is recommended to testing when
* flate: Fix rare huffman only (-2) corruption. [#503](https://github.com/klauspost/compress/pull/503)
* zip: Update deprecated CreateHeaderRaw to correctly call CreateRaw by @saracen in [#502](https://github.com/klauspost/compress/pull/502)
* zip: don't read data descriptor early by @saracen in [#501](https://github.com/klauspost/compress/pull/501) #501
* huff0: Use static decompression buffer up to 30% faster by @klauspost in [#499](https://github.com/klauspost/compress/pull/499) [#500](https://github.com/klauspost/compress/pull/500)
* huff0: Use static decompression buffer up to 30% faster [#499](https://github.com/klauspost/compress/pull/499) [#500](https://github.com/klauspost/compress/pull/500)
* Feb 17, 2022 (v1.14.3)
* flate: Improve fastest levels compression speed ~10% more throughput. [#482](https://github.com/klauspost/compress/pull/482) [#489](https://github.com/klauspost/compress/pull/489) [#490](https://github.com/klauspost/compress/pull/490) [#491](https://github.com/klauspost/compress/pull/491) [#494](https://github.com/klauspost/compress/pull/494) [#478](https://github.com/klauspost/compress/pull/478)
@ -565,12 +591,14 @@ While the release has been extensively tested, it is recommended to testing when
The packages are drop-in replacements for standard libraries. Simply replace the import path to use them:
| old import | new import | Documentation
|--------------------|-----------------------------------------|--------------------|
| `compress/gzip` | `github.com/klauspost/compress/gzip` | [gzip](https://pkg.go.dev/github.com/klauspost/compress/gzip?tab=doc)
| `compress/zlib` | `github.com/klauspost/compress/zlib` | [zlib](https://pkg.go.dev/github.com/klauspost/compress/zlib?tab=doc)
| `archive/zip` | `github.com/klauspost/compress/zip` | [zip](https://pkg.go.dev/github.com/klauspost/compress/zip?tab=doc)
| `compress/flate` | `github.com/klauspost/compress/flate` | [flate](https://pkg.go.dev/github.com/klauspost/compress/flate?tab=doc)
Typical speed is about 2x of the standard library packages.
| old import | new import | Documentation |
|------------------|---------------------------------------|-------------------------------------------------------------------------|
| `compress/gzip` | `github.com/klauspost/compress/gzip` | [gzip](https://pkg.go.dev/github.com/klauspost/compress/gzip?tab=doc) |
| `compress/zlib` | `github.com/klauspost/compress/zlib` | [zlib](https://pkg.go.dev/github.com/klauspost/compress/zlib?tab=doc) |
| `archive/zip` | `github.com/klauspost/compress/zip` | [zip](https://pkg.go.dev/github.com/klauspost/compress/zip?tab=doc) |
| `compress/flate` | `github.com/klauspost/compress/flate` | [flate](https://pkg.go.dev/github.com/klauspost/compress/flate?tab=doc) |
* Optimized [deflate](https://godoc.org/github.com/klauspost/compress/flate) packages which can be used as a dropin replacement for [gzip](https://godoc.org/github.com/klauspost/compress/gzip), [zip](https://godoc.org/github.com/klauspost/compress/zip) and [zlib](https://godoc.org/github.com/klauspost/compress/zlib).
@ -625,84 +653,6 @@ This will only use up to 4KB in memory when the writer is idle.
Compression is almost always worse than the fastest compression level
and each write will allocate (a little) memory.
# Performance Update 2018
It has been a while since we have been looking at the speed of this package compared to the standard library, so I thought I would re-do my tests and give some overall recommendations based on the current state. All benchmarks have been performed with Go 1.10 on my Desktop Intel(R) Core(TM) i7-2600 CPU @3.40GHz. Since I last ran the tests, I have gotten more RAM, which means tests with big files are no longer limited by my SSD.
The raw results are in my [updated spreadsheet](https://docs.google.com/spreadsheets/d/1nuNE2nPfuINCZJRMt6wFWhKpToF95I47XjSsc-1rbPQ/edit?usp=sharing). Due to cgo changes and upstream updates i could not get the cgo version of gzip to compile. Instead I included the [zstd](https://github.com/datadog/zstd) cgo implementation. If I get cgo gzip to work again, I might replace the results in the sheet.
The columns to take note of are: *MB/s* - the throughput. *Reduction* - the data size reduction in percent of the original. *Rel Speed* relative speed compared to the standard library at the same level. *Smaller* - how many percent smaller is the compressed output compared to stdlib. Negative means the output was bigger. *Loss* means the loss (or gain) in compression as a percentage difference of the input.
The `gzstd` (standard library gzip) and `gzkp` (this package gzip) only uses one CPU core. [`pgzip`](https://github.com/klauspost/pgzip), [`bgzf`](https://github.com/biogo/hts/tree/master/bgzf) uses all 4 cores. [`zstd`](https://github.com/DataDog/zstd) uses one core, and is a beast (but not Go, yet).
## Overall differences.
There appears to be a roughly 5-10% speed advantage over the standard library when comparing at similar compression levels.
The biggest difference you will see is the result of [re-balancing](https://blog.klauspost.com/rebalancing-deflate-compression-levels/) the compression levels. I wanted by library to give a smoother transition between the compression levels than the standard library.
This package attempts to provide a more smooth transition, where "1" is taking a lot of shortcuts, "5" is the reasonable trade-off and "9" is the "give me the best compression", and the values in between gives something reasonable in between. The standard library has big differences in levels 1-4, but levels 5-9 having no significant gains - often spending a lot more time than can be justified by the achieved compression.
There are links to all the test data in the [spreadsheet](https://docs.google.com/spreadsheets/d/1nuNE2nPfuINCZJRMt6wFWhKpToF95I47XjSsc-1rbPQ/edit?usp=sharing) in the top left field on each tab.
## Web Content
This test set aims to emulate typical use in a web server. The test-set is 4GB data in 53k files, and is a mixture of (mostly) HTML, JS, CSS.
Since level 1 and 9 are close to being the same code, they are quite close. But looking at the levels in-between the differences are quite big.
Looking at level 6, this package is 88% faster, but will output about 6% more data. For a web server, this means you can serve 88% more data, but have to pay for 6% more bandwidth. You can draw your own conclusions on what would be the most expensive for your case.
## Object files
This test is for typical data files stored on a server. In this case it is a collection of Go precompiled objects. They are very compressible.
The picture is similar to the web content, but with small differences since this is very compressible. Levels 2-3 offer good speed, but is sacrificing quite a bit of compression.
The standard library seems suboptimal on level 3 and 4 - offering both worse compression and speed than level 6 & 7 of this package respectively.
## Highly Compressible File
This is a JSON file with very high redundancy. The reduction starts at 95% on level 1, so in real life terms we are dealing with something like a highly redundant stream of data, etc.
It is definitely visible that we are dealing with specialized content here, so the results are very scattered. This package does not do very well at levels 1-4, but picks up significantly at level 5 and levels 7 and 8 offering great speed for the achieved compression.
So if you know you content is extremely compressible you might want to go slightly higher than the defaults. The standard library has a huge gap between levels 3 and 4 in terms of speed (2.75x slowdown), so it offers little "middle ground".
## Medium-High Compressible
This is a pretty common test corpus: [enwik9](http://mattmahoney.net/dc/textdata.html). It contains the first 10^9 bytes of the English Wikipedia dump on Mar. 3, 2006. This is a very good test of typical text based compression and more data heavy streams.
We see a similar picture here as in "Web Content". On equal levels some compression is sacrificed for more speed. Level 5 seems to be the best trade-off between speed and size, beating stdlib level 3 in both.
## Medium Compressible
I will combine two test sets, one [10GB file set](http://mattmahoney.net/dc/10gb.html) and a VM disk image (~8GB). Both contain different data types and represent a typical backup scenario.
The most notable thing is how quickly the standard library drops to very low compression speeds around level 5-6 without any big gains in compression. Since this type of data is fairly common, this does not seem like good behavior.
## Un-compressible Content
This is mainly a test of how good the algorithms are at detecting un-compressible input. The standard library only offers this feature with very conservative settings at level 1. Obviously there is no reason for the algorithms to try to compress input that cannot be compressed. The only downside is that it might skip some compressible data on false detections.
## Huffman only compression
This compression library adds a special compression level, named `HuffmanOnly`, which allows near linear time compression. This is done by completely disabling matching of previous data, and only reduce the number of bits to represent each character.
This means that often used characters, like 'e' and ' ' (space) in text use the fewest bits to represent, and rare characters like '¤' takes more bits to represent. For more information see [wikipedia](https://en.wikipedia.org/wiki/Huffman_coding) or this nice [video](https://youtu.be/ZdooBTdW5bM).
Since this type of compression has much less variance, the compression speed is mostly unaffected by the input data, and is usually more than *180MB/s* for a single core.
The downside is that the compression ratio is usually considerably worse than even the fastest conventional compression. The compression ratio can never be better than 8:1 (12.5%).
The linear time compression can be used as a "better than nothing" mode, where you cannot risk the encoder to slow down on some content. For comparison, the size of the "Twain" text is *233460 bytes* (+29% vs. level 1) and encode speed is 144MB/s (4.5x level 1). So in this case you trade a 30% size increase for a 4 times speedup.
For more information see my blog post on [Fast Linear Time Compression](http://blog.klauspost.com/constant-time-gzipzip-compression/).
This is implemented on Go 1.7 as "Huffman Only" mode, though not exposed for gzip.
# Other packages

View File

@ -6,10 +6,11 @@
package huff0
import (
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/klauspost/compress/internal/le"
)
// bitReader reads a bitstream in reverse.
@ -46,7 +47,7 @@ func (b *bitReaderBytes) init(in []byte) error {
return nil
}
// peekBitsFast requires that at least one bit is requested every time.
// peekByteFast requires that at least one byte is requested every time.
// There are no checks if the buffer is filled.
func (b *bitReaderBytes) peekByteFast() uint8 {
got := uint8(b.value >> 56)
@ -66,8 +67,7 @@ func (b *bitReaderBytes) fillFast() {
}
// 2 bounds checks.
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << (b.bitsRead - 32)
b.bitsRead -= 32
b.off -= 4
@ -76,7 +76,7 @@ func (b *bitReaderBytes) fillFast() {
// fillFastStart() assumes the bitReaderBytes is empty and there is at least 8 bytes to read.
func (b *bitReaderBytes) fillFastStart() {
// Do single re-slice to avoid bounds checks.
b.value = binary.LittleEndian.Uint64(b.in[b.off-8:])
b.value = le.Load64(b.in, b.off-8)
b.bitsRead = 0
b.off -= 8
}
@ -86,9 +86,8 @@ func (b *bitReaderBytes) fill() {
if b.bitsRead < 32 {
return
}
if b.off > 4 {
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
if b.off >= 4 {
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << (b.bitsRead - 32)
b.bitsRead -= 32
b.off -= 4
@ -175,9 +174,7 @@ func (b *bitReaderShifted) fillFast() {
return
}
// 2 bounds checks.
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << ((b.bitsRead - 32) & 63)
b.bitsRead -= 32
b.off -= 4
@ -185,8 +182,7 @@ func (b *bitReaderShifted) fillFast() {
// fillFastStart() assumes the bitReaderShifted is empty and there is at least 8 bytes to read.
func (b *bitReaderShifted) fillFastStart() {
// Do single re-slice to avoid bounds checks.
b.value = binary.LittleEndian.Uint64(b.in[b.off-8:])
b.value = le.Load64(b.in, b.off-8)
b.bitsRead = 0
b.off -= 8
}
@ -197,8 +193,7 @@ func (b *bitReaderShifted) fill() {
return
}
if b.off > 4 {
v := b.in[b.off-4 : b.off]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
low := le.Load32(b.in, b.off-4)
b.value |= uint64(low) << ((b.bitsRead - 32) & 63)
b.bitsRead -= 32
b.off -= 4

View File

@ -0,0 +1,5 @@
package le
type Indexer interface {
int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64
}

View File

@ -0,0 +1,42 @@
//go:build !(amd64 || arm64 || ppc64le || riscv64) || nounsafe || purego || appengine
package le
import (
"encoding/binary"
)
// Load8 will load from b at index i.
func Load8[I Indexer](b []byte, i I) byte {
return b[i]
}
// Load16 will load from b at index i.
func Load16[I Indexer](b []byte, i I) uint16 {
return binary.LittleEndian.Uint16(b[i:])
}
// Load32 will load from b at index i.
func Load32[I Indexer](b []byte, i I) uint32 {
return binary.LittleEndian.Uint32(b[i:])
}
// Load64 will load from b at index i.
func Load64[I Indexer](b []byte, i I) uint64 {
return binary.LittleEndian.Uint64(b[i:])
}
// Store16 will store v at b.
func Store16(b []byte, v uint16) {
binary.LittleEndian.PutUint16(b, v)
}
// Store32 will store v at b.
func Store32(b []byte, v uint32) {
binary.LittleEndian.PutUint32(b, v)
}
// Store64 will store v at b.
func Store64(b []byte, v uint64) {
binary.LittleEndian.PutUint64(b, v)
}

View File

@ -0,0 +1,55 @@
// We enable 64 bit LE platforms:
//go:build (amd64 || arm64 || ppc64le || riscv64) && !nounsafe && !purego && !appengine
package le
import (
"unsafe"
)
// Load8 will load from b at index i.
func Load8[I Indexer](b []byte, i I) byte {
//return binary.LittleEndian.Uint16(b[i:])
//return *(*uint16)(unsafe.Pointer(&b[i]))
return *(*byte)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Load16 will load from b at index i.
func Load16[I Indexer](b []byte, i I) uint16 {
//return binary.LittleEndian.Uint16(b[i:])
//return *(*uint16)(unsafe.Pointer(&b[i]))
return *(*uint16)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Load32 will load from b at index i.
func Load32[I Indexer](b []byte, i I) uint32 {
//return binary.LittleEndian.Uint32(b[i:])
//return *(*uint32)(unsafe.Pointer(&b[i]))
return *(*uint32)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Load64 will load from b at index i.
func Load64[I Indexer](b []byte, i I) uint64 {
//return binary.LittleEndian.Uint64(b[i:])
//return *(*uint64)(unsafe.Pointer(&b[i]))
return *(*uint64)(unsafe.Add(unsafe.Pointer(unsafe.SliceData(b)), i))
}
// Store16 will store v at b.
func Store16(b []byte, v uint16) {
//binary.LittleEndian.PutUint16(b, v)
*(*uint16)(unsafe.Pointer(unsafe.SliceData(b))) = v
}
// Store32 will store v at b.
func Store32(b []byte, v uint32) {
//binary.LittleEndian.PutUint32(b, v)
*(*uint32)(unsafe.Pointer(unsafe.SliceData(b))) = v
}
// Store64 will store v at b.
func Store64(b []byte, v uint64) {
//binary.LittleEndian.PutUint64(b, v)
*(*uint64)(unsafe.Pointer(unsafe.SliceData(b))) = v
}

View File

@ -79,7 +79,7 @@ This will take ownership of the buffer until the stream is closed.
func EncodeStream(src []byte, dst io.Writer) error {
enc := s2.NewWriter(dst)
// The encoder owns the buffer until Flush or Close is called.
err := enc.EncodeBuffer(buf)
err := enc.EncodeBuffer(src)
if err != nil {
enc.Close()
return err

View File

@ -11,6 +11,8 @@ package s2
import (
"fmt"
"strconv"
"github.com/klauspost/compress/internal/le"
)
// decode writes the decoding of src to dst. It assumes that the varint-encoded
@ -38,21 +40,18 @@ func s2Decode(dst, src []byte) int {
case x < 60:
s++
case x == 60:
x = uint32(src[s+1])
s += 2
x = uint32(src[s-1])
case x == 61:
in := src[s : s+3]
x = uint32(in[1]) | uint32(in[2])<<8
x = uint32(le.Load16(src, s+1))
s += 3
case x == 62:
in := src[s : s+4]
// Load as 32 bit and shift down.
x = uint32(in[0]) | uint32(in[1])<<8 | uint32(in[2])<<16 | uint32(in[3])<<24
x = le.Load32(src, s)
x >>= 8
s += 4
case x == 63:
in := src[s : s+5]
x = uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24
x = le.Load32(src, s+1)
s += 5
}
length = int(x) + 1
@ -85,8 +84,7 @@ func s2Decode(dst, src []byte) int {
length = int(src[s]) + 4
s += 1
case 6:
in := src[s : s+2]
length = int(uint32(in[0])|(uint32(in[1])<<8)) + (1 << 8)
length = int(le.Load16(src, s)) + 1<<8
s += 2
case 7:
in := src[s : s+3]
@ -99,15 +97,13 @@ func s2Decode(dst, src []byte) int {
}
length += 4
case tagCopy2:
in := src[s : s+3]
offset = int(uint32(in[1]) | uint32(in[2])<<8)
length = 1 + int(in[0])>>2
offset = int(le.Load16(src, s+1))
length = 1 + int(src[s])>>2
s += 3
case tagCopy4:
in := src[s : s+5]
offset = int(uint32(in[1]) | uint32(in[2])<<8 | uint32(in[3])<<16 | uint32(in[4])<<24)
length = 1 + int(in[0])>>2
offset = int(le.Load32(src, s+1))
length = 1 + int(src[s])>>2
s += 5
}

View File

@ -10,14 +10,16 @@ import (
"encoding/binary"
"fmt"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
func load32(b []byte, i int) uint32 {
return binary.LittleEndian.Uint32(b[i:])
return le.Load32(b, i)
}
func load64(b []byte, i int) uint64 {
return binary.LittleEndian.Uint64(b[i:])
return le.Load64(b, i)
}
// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits.
@ -44,7 +46,12 @@ func encodeGo(dst, src []byte) []byte {
d += emitLiteral(dst[d:], src)
return dst[:d]
}
n := encodeBlockGo(dst[d:], src)
var n int
if len(src) < 64<<10 {
n = encodeBlockGo64K(dst[d:], src)
} else {
n = encodeBlockGo(dst[d:], src)
}
if n > 0 {
d += n
return dst[:d]
@ -70,7 +77,6 @@ func encodeBlockGo(dst, src []byte) (d int) {
debug = false
)
var table [maxTableSize]uint32
// sLimit is when to stop looking for offset/length copies. The inputMargin
@ -277,13 +283,228 @@ emitRemainder:
return d
}
// encodeBlockGo64K is a specialized version for compressing blocks <= 64KB
func encodeBlockGo64K(dst, src []byte) (d int) {
// Initialize the hash table.
const (
tableBits = 14
maxTableSize = 1 << tableBits
debug = false
)
var table [maxTableSize]uint16
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 5
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
// We search for a repeat at -1, but don't output repeats when nextEmit == 0
repeat := 1
for {
candidate := 0
for {
// Next src position to check
nextS := s + (s-nextEmit)>>5 + 4
if nextS > sLimit {
goto emitRemainder
}
hash0 := hash6(cv, tableBits)
hash1 := hash6(cv>>8, tableBits)
candidate = int(table[hash0])
candidate2 := int(table[hash1])
table[hash0] = uint16(s)
table[hash1] = uint16(s + 1)
hash2 := hash6(cv>>16, tableBits)
// Check repeat at offset checkRep.
const checkRep = 1
if uint32(cv>>(checkRep*8)) == load32(src, s-repeat+checkRep) {
base := s + checkRep
// Extend back
for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; {
i--
base--
}
// Bail if we exceed the maximum size.
if d+(base-nextEmit) > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:base])
// Extend forward
candidate := s - repeat + 4 + checkRep
s += 4 + checkRep
for s <= sLimit {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
if debug {
// Validate match.
if s <= candidate {
panic("s <= candidate")
}
a := src[base:s]
b := src[base-repeat : base-repeat+(s-base)]
if !bytes.Equal(a, b) {
panic("mismatch")
}
}
if nextEmit > 0 {
// same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset.
d += emitRepeat(dst[d:], repeat, s-base)
} else {
// First match, cannot be repeat.
d += emitCopy(dst[d:], repeat, s-base)
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
cv = load64(src, s)
continue
}
if uint32(cv) == load32(src, candidate) {
break
}
candidate = int(table[hash2])
if uint32(cv>>8) == load32(src, candidate2) {
table[hash2] = uint16(s + 2)
candidate = candidate2
s++
break
}
table[hash2] = uint16(s + 2)
if uint32(cv>>16) == load32(src, candidate) {
s += 2
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards.
// The top bytes will be rechecked to get the full match.
for candidate > 0 && s > nextEmit && src[candidate-1] == src[s-1] {
candidate--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
d += emitLiteral(dst[d:], src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
base := s
repeat = base - candidate
// Extend the 4-byte match as long as possible.
s += 4
candidate += 4
for s <= len(src)-8 {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
d += emitCopy(dst[d:], repeat, s-base)
if debug {
// Validate match.
if s <= candidate {
panic("s <= candidate")
}
a := src[base:s]
b := src[base-repeat : base-repeat+(s-base)]
if !bytes.Equal(a, b) {
panic("mismatch")
}
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Check for an immediate match, otherwise start search at s+1
x := load64(src, s-2)
m2Hash := hash6(x, tableBits)
currHash := hash6(x>>16, tableBits)
candidate = int(table[currHash])
table[m2Hash] = uint16(s - 2)
table[currHash] = uint16(s)
if debug && s == candidate {
panic("s == candidate")
}
if uint32(x>>16) != load32(src, candidate) {
cv = load64(src, s+1)
s++
break
}
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
func encodeBlockSnappyGo(dst, src []byte) (d int) {
// Initialize the hash table.
const (
tableBits = 14
maxTableSize = 1 << tableBits
)
var table [maxTableSize]uint32
// sLimit is when to stop looking for offset/length copies. The inputMargin
@ -467,6 +688,197 @@ emitRemainder:
return d
}
// encodeBlockSnappyGo64K is a special version of encodeBlockSnappyGo for sizes <64KB
func encodeBlockSnappyGo64K(dst, src []byte) (d int) {
// Initialize the hash table.
const (
tableBits = 14
maxTableSize = 1 << tableBits
)
var table [maxTableSize]uint16
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 5
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
// We search for a repeat at -1, but don't output repeats when nextEmit == 0
repeat := 1
for {
candidate := 0
for {
// Next src position to check
nextS := s + (s-nextEmit)>>5 + 4
if nextS > sLimit {
goto emitRemainder
}
hash0 := hash6(cv, tableBits)
hash1 := hash6(cv>>8, tableBits)
candidate = int(table[hash0])
candidate2 := int(table[hash1])
table[hash0] = uint16(s)
table[hash1] = uint16(s + 1)
hash2 := hash6(cv>>16, tableBits)
// Check repeat at offset checkRep.
const checkRep = 1
if uint32(cv>>(checkRep*8)) == load32(src, s-repeat+checkRep) {
base := s + checkRep
// Extend back
for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; {
i--
base--
}
// Bail if we exceed the maximum size.
if d+(base-nextEmit) > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:base])
// Extend forward
candidate := s - repeat + 4 + checkRep
s += 4 + checkRep
for s <= sLimit {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
d += emitCopyNoRepeat(dst[d:], repeat, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
cv = load64(src, s)
continue
}
if uint32(cv) == load32(src, candidate) {
break
}
candidate = int(table[hash2])
if uint32(cv>>8) == load32(src, candidate2) {
table[hash2] = uint16(s + 2)
candidate = candidate2
s++
break
}
table[hash2] = uint16(s + 2)
if uint32(cv>>16) == load32(src, candidate) {
s += 2
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards
for candidate > 0 && s > nextEmit && src[candidate-1] == src[s-1] {
candidate--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
d += emitLiteral(dst[d:], src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
base := s
repeat = base - candidate
// Extend the 4-byte match as long as possible.
s += 4
candidate += 4
for s <= len(src)-8 {
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
d += emitCopyNoRepeat(dst[d:], repeat, s-base)
if false {
// Validate match.
a := src[base:s]
b := src[base-repeat : base-repeat+(s-base)]
if !bytes.Equal(a, b) {
panic("mismatch")
}
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Check for an immediate match, otherwise start search at s+1
x := load64(src, s-2)
m2Hash := hash6(x, tableBits)
currHash := hash6(x>>16, tableBits)
candidate = int(table[currHash])
table[m2Hash] = uint16(s - 2)
table[currHash] = uint16(s)
if uint32(x>>16) != load32(src, candidate) {
cv = load64(src, s+1)
s++
break
}
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
// encodeBlockGo encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.

View File

@ -348,12 +348,7 @@ func encodeBlockBetterSnappyGo(dst, src []byte) (d int) {
nextS := 0
for {
// Next src position to check
nextS = (s-nextEmit)>>7 + 1
if nextS > maxSkip {
nextS = s + maxSkip
} else {
nextS += s
}
nextS = min(s+(s-nextEmit)>>7+1, s+maxSkip)
if nextS > sLimit {
goto emitRemainder
@ -483,6 +478,415 @@ emitRemainder:
return d
}
func encodeBlockBetterGo64K(dst, src []byte) (d int) {
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
if len(src) < minNonLiteralBlockSize {
return 0
}
// Initialize the hash tables.
// Use smaller tables for smaller blocks
const (
// Long hash matches.
lTableBits = 16
maxLTableSize = 1 << lTableBits
// Short hash matches.
sTableBits = 13
maxSTableSize = 1 << sTableBits
)
var lTable [maxLTableSize]uint16
var sTable [maxSTableSize]uint16
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 6
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
// We initialize repeat to 0, so we never match on first attempt
repeat := 0
for {
candidateL := 0
nextS := 0
for {
// Next src position to check
nextS = s + (s-nextEmit)>>6 + 1
if nextS > sLimit {
goto emitRemainder
}
hashL := hash7(cv, lTableBits)
hashS := hash4(cv, sTableBits)
candidateL = int(lTable[hashL])
candidateS := int(sTable[hashS])
lTable[hashL] = uint16(s)
sTable[hashS] = uint16(s)
valLong := load64(src, candidateL)
valShort := load64(src, candidateS)
// If long matches at least 8 bytes, use that.
if cv == valLong {
break
}
if cv == valShort {
candidateL = candidateS
break
}
// Check repeat at offset checkRep.
const checkRep = 1
// Minimum length of a repeat. Tested with various values.
// While 4-5 offers improvements in some, 6 reduces
// regressions significantly.
const wantRepeatBytes = 6
const repeatMask = ((1 << (wantRepeatBytes * 8)) - 1) << (8 * checkRep)
if false && repeat > 0 && cv&repeatMask == load64(src, s-repeat)&repeatMask {
base := s + checkRep
// Extend back
for i := base - repeat; base > nextEmit && i > 0 && src[i-1] == src[base-1]; {
i--
base--
}
d += emitLiteral(dst[d:], src[nextEmit:base])
// Extend forward
candidate := s - repeat + wantRepeatBytes + checkRep
s += wantRepeatBytes + checkRep
for s < len(src) {
if len(src)-s < 8 {
if src[s] == src[candidate] {
s++
candidate++
continue
}
break
}
if diff := load64(src, s) ^ load64(src, candidate); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidate += 8
}
// same as `add := emitCopy(dst[d:], repeat, s-base)` but skips storing offset.
d += emitRepeat(dst[d:], repeat, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
// Index in-between
index0 := base + 1
index1 := s - 2
for index0 < index1 {
cv0 := load64(src, index0)
cv1 := load64(src, index1)
lTable[hash7(cv0, lTableBits)] = uint16(index0)
sTable[hash4(cv0>>8, sTableBits)] = uint16(index0 + 1)
lTable[hash7(cv1, lTableBits)] = uint16(index1)
sTable[hash4(cv1>>8, sTableBits)] = uint16(index1 + 1)
index0 += 2
index1 -= 2
}
cv = load64(src, s)
continue
}
// Long likely matches 7, so take that.
if uint32(cv) == uint32(valLong) {
break
}
// Check our short candidate
if uint32(cv) == uint32(valShort) {
// Try a long candidate at s+1
hashL = hash7(cv>>8, lTableBits)
candidateL = int(lTable[hashL])
lTable[hashL] = uint16(s + 1)
if uint32(cv>>8) == load32(src, candidateL) {
s++
break
}
// Use our short candidate.
candidateL = candidateS
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards
for candidateL > 0 && s > nextEmit && src[candidateL-1] == src[s-1] {
candidateL--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
base := s
offset := base - candidateL
// Extend the 4-byte match as long as possible.
s += 4
candidateL += 4
for s < len(src) {
if len(src)-s < 8 {
if src[s] == src[candidateL] {
s++
candidateL++
continue
}
break
}
if diff := load64(src, s) ^ load64(src, candidateL); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidateL += 8
}
d += emitLiteral(dst[d:], src[nextEmit:base])
if repeat == offset {
d += emitRepeat(dst[d:], offset, s-base)
} else {
d += emitCopy(dst[d:], offset, s-base)
repeat = offset
}
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Index short & long
index0 := base + 1
index1 := s - 2
cv0 := load64(src, index0)
cv1 := load64(src, index1)
lTable[hash7(cv0, lTableBits)] = uint16(index0)
sTable[hash4(cv0>>8, sTableBits)] = uint16(index0 + 1)
// lTable could be postponed, but very minor difference.
lTable[hash7(cv1, lTableBits)] = uint16(index1)
sTable[hash4(cv1>>8, sTableBits)] = uint16(index1 + 1)
index0 += 1
index1 -= 1
cv = load64(src, s)
// Index large values sparsely in between.
// We do two starting from different offsets for speed.
index2 := (index0 + index1 + 1) >> 1
for index2 < index1 {
lTable[hash7(load64(src, index0), lTableBits)] = uint16(index0)
lTable[hash7(load64(src, index2), lTableBits)] = uint16(index2)
index0 += 2
index2 += 2
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
// encodeBlockBetterSnappyGo encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.
//
// It also assumes that:
//
// len(dst) >= MaxEncodedLen(len(src)) &&
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlockBetterSnappyGo64K(dst, src []byte) (d int) {
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
if len(src) < minNonLiteralBlockSize {
return 0
}
// Initialize the hash tables.
// Use smaller tables for smaller blocks
const (
// Long hash matches.
lTableBits = 15
maxLTableSize = 1 << lTableBits
// Short hash matches.
sTableBits = 13
maxSTableSize = 1 << sTableBits
)
var lTable [maxLTableSize]uint16
var sTable [maxSTableSize]uint16
// Bail if we can't compress to at least this.
dstLimit := len(src) - len(src)>>5 - 6
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
cv := load64(src, s)
const maxSkip = 100
for {
candidateL := 0
nextS := 0
for {
// Next src position to check
nextS = min(s+(s-nextEmit)>>6+1, s+maxSkip)
if nextS > sLimit {
goto emitRemainder
}
hashL := hash7(cv, lTableBits)
hashS := hash4(cv, sTableBits)
candidateL = int(lTable[hashL])
candidateS := int(sTable[hashS])
lTable[hashL] = uint16(s)
sTable[hashS] = uint16(s)
if uint32(cv) == load32(src, candidateL) {
break
}
// Check our short candidate
if uint32(cv) == load32(src, candidateS) {
// Try a long candidate at s+1
hashL = hash7(cv>>8, lTableBits)
candidateL = int(lTable[hashL])
lTable[hashL] = uint16(s + 1)
if uint32(cv>>8) == load32(src, candidateL) {
s++
break
}
// Use our short candidate.
candidateL = candidateS
break
}
cv = load64(src, nextS)
s = nextS
}
// Extend backwards
for candidateL > 0 && s > nextEmit && src[candidateL-1] == src[s-1] {
candidateL--
s--
}
// Bail if we exceed the maximum size.
if d+(s-nextEmit) > dstLimit {
return 0
}
base := s
offset := base - candidateL
// Extend the 4-byte match as long as possible.
s += 4
candidateL += 4
for s < len(src) {
if len(src)-s < 8 {
if src[s] == src[candidateL] {
s++
candidateL++
continue
}
break
}
if diff := load64(src, s) ^ load64(src, candidateL); diff != 0 {
s += bits.TrailingZeros64(diff) >> 3
break
}
s += 8
candidateL += 8
}
d += emitLiteral(dst[d:], src[nextEmit:base])
d += emitCopyNoRepeat(dst[d:], offset, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
if d > dstLimit {
// Do we have space for more, if not bail.
return 0
}
// Index short & long
index0 := base + 1
index1 := s - 2
cv0 := load64(src, index0)
cv1 := load64(src, index1)
lTable[hash7(cv0, lTableBits)] = uint16(index0)
sTable[hash4(cv0>>8, sTableBits)] = uint16(index0 + 1)
lTable[hash7(cv1, lTableBits)] = uint16(index1)
sTable[hash4(cv1>>8, sTableBits)] = uint16(index1 + 1)
index0 += 1
index1 -= 1
cv = load64(src, s)
// Index large values sparsely in between.
// We do two starting from different offsets for speed.
index2 := (index0 + index1 + 1) >> 1
for index2 < index1 {
lTable[hash7(load64(src, index0), lTableBits)] = uint16(index0)
lTable[hash7(load64(src, index2), lTableBits)] = uint16(index2)
index0 += 2
index2 += 2
}
}
emitRemainder:
if nextEmit < len(src) {
// Bail if we exceed the maximum size.
if d+len(src)-nextEmit > dstLimit {
return 0
}
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}
// encodeBlockBetterDict encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.

View File

@ -21,6 +21,9 @@ func encodeBlock(dst, src []byte) (d int) {
if len(src) < minNonLiteralBlockSize {
return 0
}
if len(src) <= 64<<10 {
return encodeBlockGo64K(dst, src)
}
return encodeBlockGo(dst, src)
}
@ -32,6 +35,9 @@ func encodeBlock(dst, src []byte) (d int) {
//
// len(dst) >= MaxEncodedLen(len(src))
func encodeBlockBetter(dst, src []byte) (d int) {
if len(src) <= 64<<10 {
return encodeBlockBetterGo64K(dst, src)
}
return encodeBlockBetterGo(dst, src)
}
@ -43,6 +49,9 @@ func encodeBlockBetter(dst, src []byte) (d int) {
//
// len(dst) >= MaxEncodedLen(len(src))
func encodeBlockBetterSnappy(dst, src []byte) (d int) {
if len(src) <= 64<<10 {
return encodeBlockBetterSnappyGo64K(dst, src)
}
return encodeBlockBetterSnappyGo(dst, src)
}
@ -57,6 +66,9 @@ func encodeBlockSnappy(dst, src []byte) (d int) {
if len(src) < minNonLiteralBlockSize {
return 0
}
if len(src) <= 64<<10 {
return encodeBlockSnappyGo64K(dst, src)
}
return encodeBlockSnappyGo(dst, src)
}

View File

@ -1,4 +1,3 @@
module github.com/klauspost/compress
go 1.19
go 1.22

View File

@ -6,7 +6,7 @@ A high performance compression algorithm is implemented. For now focused on spee
This package provides [compression](#Compressor) to and [decompression](#Decompressor) of Zstandard content.
This package is pure Go and without use of "unsafe".
This package is pure Go. Use `noasm` and `nounsafe` to disable relevant features.
The `zstd` package is provided as open source software using a Go standard license.

View File

@ -5,11 +5,12 @@
package zstd
import (
"encoding/binary"
"errors"
"fmt"
"io"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
// bitReader reads a bitstream in reverse.
@ -18,6 +19,7 @@ import (
type bitReader struct {
in []byte
value uint64 // Maybe use [16]byte, but shifting is awkward.
cursor int // offset where next read should end
bitsRead uint8
}
@ -32,6 +34,7 @@ func (b *bitReader) init(in []byte) error {
if v == 0 {
return errors.New("corrupt stream, did not find end of stream")
}
b.cursor = len(in)
b.bitsRead = 64
b.value = 0
if len(in) >= 8 {
@ -67,18 +70,15 @@ func (b *bitReader) fillFast() {
if b.bitsRead < 32 {
return
}
v := b.in[len(b.in)-4:]
b.in = b.in[:len(b.in)-4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
b.value = (b.value << 32) | uint64(low)
b.cursor -= 4
b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
b.bitsRead -= 32
}
// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read.
func (b *bitReader) fillFastStart() {
v := b.in[len(b.in)-8:]
b.in = b.in[:len(b.in)-8]
b.value = binary.LittleEndian.Uint64(v)
b.cursor -= 8
b.value = le.Load64(b.in, b.cursor)
b.bitsRead = 0
}
@ -87,25 +87,23 @@ func (b *bitReader) fill() {
if b.bitsRead < 32 {
return
}
if len(b.in) >= 4 {
v := b.in[len(b.in)-4:]
b.in = b.in[:len(b.in)-4]
low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
b.value = (b.value << 32) | uint64(low)
if b.cursor >= 4 {
b.cursor -= 4
b.value = (b.value << 32) | uint64(le.Load32(b.in, b.cursor))
b.bitsRead -= 32
return
}
b.bitsRead -= uint8(8 * len(b.in))
for len(b.in) > 0 {
b.value = (b.value << 8) | uint64(b.in[len(b.in)-1])
b.in = b.in[:len(b.in)-1]
b.bitsRead -= uint8(8 * b.cursor)
for b.cursor > 0 {
b.cursor -= 1
b.value = (b.value << 8) | uint64(b.in[b.cursor])
}
}
// finished returns true if all bits have been read from the bit stream.
func (b *bitReader) finished() bool {
return len(b.in) == 0 && b.bitsRead >= 64
return b.cursor == 0 && b.bitsRead >= 64
}
// overread returns true if more bits have been requested than is on the stream.
@ -115,13 +113,14 @@ func (b *bitReader) overread() bool {
// remain returns the number of bits remaining.
func (b *bitReader) remain() uint {
return 8*uint(len(b.in)) + 64 - uint(b.bitsRead)
return 8*uint(b.cursor) + 64 - uint(b.bitsRead)
}
// close the bitstream and returns an error if out-of-buffer reads occurred.
func (b *bitReader) close() error {
// Release reference.
b.in = nil
b.cursor = 0
if !b.finished() {
return fmt.Errorf("%d extra bits on block, should be 0", b.remain())
}

View File

@ -5,14 +5,10 @@
package zstd
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"
"os"
"path/filepath"
"sync"
"github.com/klauspost/compress/huff0"
@ -648,21 +644,6 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) {
println("initializing sequences:", err)
return err
}
// Extract blocks...
if false && hist.dict == nil {
fatalErr := func(err error) {
if err != nil {
panic(err)
}
}
fn := fmt.Sprintf("n-%d-lits-%d-prev-%d-%d-%d-win-%d.blk", hist.decoders.nSeqs, len(hist.decoders.literals), hist.recentOffsets[0], hist.recentOffsets[1], hist.recentOffsets[2], hist.windowSize)
var buf bytes.Buffer
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.litLengths.fse))
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse))
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse))
buf.Write(in)
os.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm)
}
return nil
}

View File

@ -9,6 +9,7 @@ import (
"fmt"
"math"
"math/bits"
"slices"
"github.com/klauspost/compress/huff0"
)
@ -457,16 +458,7 @@ func fuzzFseEncoder(data []byte) int {
// All 0
return 0
}
maxCount := func(a []uint32) int {
var max uint32
for _, v := range a {
if v > max {
max = v
}
}
return int(max)
}
cnt := maxCount(hist[:maxSym])
cnt := int(slices.Max(hist[:maxSym]))
if cnt == len(data) {
// RLE
return 0
@ -884,15 +876,6 @@ func (b *blockEnc) genCodes() {
}
}
}
maxCount := func(a []uint32) int {
var max uint32
for _, v := range a {
if v > max {
max = v
}
}
return int(max)
}
if debugAsserts && mlMax > maxMatchLengthSymbol {
panic(fmt.Errorf("mlMax > maxMatchLengthSymbol (%d)", mlMax))
}
@ -903,7 +886,7 @@ func (b *blockEnc) genCodes() {
panic(fmt.Errorf("llMax > maxLiteralLengthSymbol (%d)", llMax))
}
b.coders.mlEnc.HistogramFinished(mlMax, maxCount(mlH[:mlMax+1]))
b.coders.ofEnc.HistogramFinished(ofMax, maxCount(ofH[:ofMax+1]))
b.coders.llEnc.HistogramFinished(llMax, maxCount(llH[:llMax+1]))
b.coders.mlEnc.HistogramFinished(mlMax, int(slices.Max(mlH[:mlMax+1])))
b.coders.ofEnc.HistogramFinished(ofMax, int(slices.Max(ofH[:ofMax+1])))
b.coders.llEnc.HistogramFinished(llMax, int(slices.Max(llH[:llMax+1])))
}

View File

@ -123,7 +123,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
}
// Read bytes from the decompressed stream into p.
// Returns the number of bytes written and any error that occurred.
// Returns the number of bytes read and any error that occurred.
// When the stream is done, io.EOF will be returned.
func (d *Decoder) Read(p []byte) (int, error) {
var n int
@ -323,6 +323,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
frame.bBuf = nil
if frame.history.decoders.br != nil {
frame.history.decoders.br.in = nil
frame.history.decoders.br.cursor = 0
}
d.decoders <- block
}()

View File

@ -116,7 +116,7 @@ func (e *fastBase) matchlen(s, t int32, src []byte) int32 {
panic(err)
}
if t < 0 {
err := fmt.Sprintf("s (%d) < 0", s)
err := fmt.Sprintf("t (%d) < 0", t)
panic(err)
}
if s-t > e.maxMatchOff {

View File

@ -7,20 +7,25 @@
package zstd
import (
"encoding/binary"
"math/bits"
"github.com/klauspost/compress/internal/le"
)
// matchLen returns the maximum common prefix length of a and b.
// a must be the shortest of the two.
func matchLen(a, b []byte) (n int) {
for ; len(a) >= 8 && len(b) >= 8; a, b = a[8:], b[8:] {
diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b)
left := len(a)
for left >= 8 {
diff := le.Load64(a, n) ^ le.Load64(b, n)
if diff != 0 {
return n + bits.TrailingZeros64(diff)>>3
}
n += 8
left -= 8
}
a = a[n:]
b = b[n:]
for i := range a {
if a[i] != b[i] {

View File

@ -245,7 +245,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
return io.ErrUnexpectedEOF
}
var ll, mo, ml int
if len(br.in) > 4+((maxOffsetBits+16+16)>>3) {
if br.cursor > 4+((maxOffsetBits+16+16)>>3) {
// inlined function:
// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)

View File

@ -7,9 +7,9 @@
TEXT ·sequenceDecs_decode_amd64(SB), $8-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -299,8 +299,8 @@ sequenceDecs_decode_amd64_match_len_ofs_ok:
MOVQ R13, 160(AX)
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -335,9 +335,9 @@ error_overread:
TEXT ·sequenceDecs_decode_56_amd64(SB), $8-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -598,8 +598,8 @@ sequenceDecs_decode_56_amd64_match_len_ofs_ok:
MOVQ R13, 160(AX)
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -634,9 +634,9 @@ error_overread:
TEXT ·sequenceDecs_decode_bmi2(SB), $8-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -884,8 +884,8 @@ sequenceDecs_decode_bmi2_match_len_ofs_ok:
MOVQ R12, 160(CX)
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -920,9 +920,9 @@ error_overread:
TEXT ·sequenceDecs_decode_56_bmi2(SB), $8-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -1141,8 +1141,8 @@ sequenceDecs_decode_56_bmi2_match_len_ofs_ok:
MOVQ R12, 160(CX)
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Return success
MOVQ $0x00000000, ret+24(FP)
@ -1787,9 +1787,9 @@ empty_seqs:
TEXT ·sequenceDecs_decodeSync_amd64(SB), $64-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -2281,8 +2281,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Update the context
MOVQ ctx+16(FP), AX
@ -2349,9 +2349,9 @@ error_not_enough_space:
TEXT ·sequenceDecs_decodeSync_bmi2(SB), $64-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -2801,8 +2801,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Update the context
MOVQ ctx+16(FP), AX
@ -2869,9 +2869,9 @@ error_not_enough_space:
TEXT ·sequenceDecs_decodeSync_safe_amd64(SB), $64-32
MOVQ br+8(FP), CX
MOVQ 24(CX), DX
MOVBQZX 32(CX), BX
MOVBQZX 40(CX), BX
MOVQ (CX), AX
MOVQ 8(CX), SI
MOVQ 32(CX), SI
ADDQ SI, AX
MOVQ AX, (SP)
MOVQ ctx+16(FP), AX
@ -3465,8 +3465,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), AX
MOVQ DX, 24(AX)
MOVB BL, 32(AX)
MOVQ SI, 8(AX)
MOVB BL, 40(AX)
MOVQ SI, 32(AX)
// Update the context
MOVQ ctx+16(FP), AX
@ -3533,9 +3533,9 @@ error_not_enough_space:
TEXT ·sequenceDecs_decodeSync_safe_bmi2(SB), $64-32
MOVQ br+8(FP), BX
MOVQ 24(BX), AX
MOVBQZX 32(BX), DX
MOVBQZX 40(BX), DX
MOVQ (BX), CX
MOVQ 8(BX), BX
MOVQ 32(BX), BX
ADDQ BX, CX
MOVQ CX, (SP)
MOVQ ctx+16(FP), CX
@ -4087,8 +4087,8 @@ handle_loop:
loop_finished:
MOVQ br+8(FP), CX
MOVQ AX, 24(CX)
MOVB DL, 32(CX)
MOVQ BX, 8(CX)
MOVB DL, 40(CX)
MOVQ BX, 32(CX)
// Update the context
MOVQ ctx+16(FP), AX

View File

@ -29,7 +29,7 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
}
for i := range seqs {
var ll, mo, ml int
if len(br.in) > 4+((maxOffsetBits+16+16)>>3) {
if br.cursor > 4+((maxOffsetBits+16+16)>>3) {
// inlined function:
// ll, mo, ml = s.nextFast(br, llState, mlState, ofState)

View File

@ -69,7 +69,6 @@ var llBitsTable = [maxLLCode + 1]byte{
func llCode(litLength uint32) uint8 {
const llDeltaCode = 19
if litLength <= 63 {
// Compiler insists on bounds check (Go 1.12)
return llCodeTable[litLength&63]
}
return uint8(highBit(litLength)) + llDeltaCode
@ -102,7 +101,6 @@ var mlBitsTable = [maxMLCode + 1]byte{
func mlCode(mlBase uint32) uint8 {
const mlDeltaCode = 36
if mlBase <= 127 {
// Compiler insists on bounds check (Go 1.12)
return mlCodeTable[mlBase&127]
}
return uint8(highBit(mlBase)) + mlDeltaCode

View File

@ -197,7 +197,7 @@ func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
n, r.err = w.Write(r.block.output)
if r.err != nil {
return written, err
return written, r.err
}
written += int64(n)
continue
@ -239,7 +239,7 @@ func (r *SnappyConverter) Convert(in io.Reader, w io.Writer) (int64, error) {
}
n, r.err = w.Write(r.block.output)
if r.err != nil {
return written, err
return written, r.err
}
written += int64(n)
continue

View File

@ -5,10 +5,11 @@ package zstd
import (
"bytes"
"encoding/binary"
"errors"
"log"
"math"
"github.com/klauspost/compress/internal/le"
)
// enable debug printing
@ -110,11 +111,11 @@ func printf(format string, a ...interface{}) {
}
func load3232(b []byte, i int32) uint32 {
return binary.LittleEndian.Uint32(b[:len(b):len(b)][i:])
return le.Load32(b, i)
}
func load6432(b []byte, i int32) uint64 {
return binary.LittleEndian.Uint64(b[:len(b):len(b)][i:])
return le.Load64(b, i)
}
type byter interface {

View File

@ -1,2 +0,0 @@
[*.go]
end_of_line = crlf

View File

@ -1,33 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
coverage.txt
# OSX
*.DS_Store
*.db
*.swp
/example/client/client
/example/server/server

View File

@ -1,13 +0,0 @@
language: go
sudo: false
go_import_path: github.com/go-oauth2/oauth2/v4
go:
- 1.13
before_install:
- go get -t -v ./...
script:
- chmod +x ./go.test.sh && ./go.test.sh
after_success:
- bash <(curl -s https://codecov.io/bash)

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2016 Lyric
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,3 +0,0 @@
# Golang OAuth 2.0 Server
Forked from [go-oauth2](https://github.com/go-oauth2/oauth2).

View File

@ -1,76 +0,0 @@
package oauth2
import (
"crypto/sha256"
"encoding/base64"
"strings"
)
// ResponseType the type of authorization request
type ResponseType string
// define the type of authorization request
const (
Code ResponseType = "code"
Token ResponseType = "token"
)
func (rt ResponseType) String() string {
return string(rt)
}
// GrantType authorization model
type GrantType string
// define authorization model
const (
AuthorizationCode GrantType = "authorization_code"
PasswordCredentials GrantType = "password"
ClientCredentials GrantType = "client_credentials"
Refreshing GrantType = "refresh_token"
Implicit GrantType = "__implicit"
)
func (gt GrantType) String() string {
if gt == AuthorizationCode ||
gt == PasswordCredentials ||
gt == ClientCredentials ||
gt == Refreshing {
return string(gt)
}
return ""
}
// CodeChallengeMethod PCKE method
type CodeChallengeMethod string
const (
// CodeChallengePlain PCKE Method
CodeChallengePlain CodeChallengeMethod = "plain"
// CodeChallengeS256 PCKE Method
CodeChallengeS256 CodeChallengeMethod = "S256"
)
func (ccm CodeChallengeMethod) String() string {
if ccm == CodeChallengePlain ||
ccm == CodeChallengeS256 {
return string(ccm)
}
return ""
}
// Validate code challenge
func (ccm CodeChallengeMethod) Validate(cc, ver string) bool {
switch ccm {
case CodeChallengePlain:
return cc == ver
case CodeChallengeS256:
s256 := sha256.Sum256([]byte(ver))
// trim padding
a := strings.TrimRight(base64.URLEncoding.EncodeToString(s256[:]), "=")
b := strings.TrimRight(cc, "=")
return a == b
default:
return false
}
}

View File

@ -1,24 +0,0 @@
// OAuth 2.0 server library for the Go programming language
//
// package main
// import (
// "net/http"
// "github.com/superseriousbusiness/oauth2/v4/manage"
// "github.com/superseriousbusiness/oauth2/v4/server"
// "github.com/superseriousbusiness/oauth2/v4/store"
// )
// func main() {
// manager := manage.NewDefaultManager()
// manager.MustTokenStorage(store.NewMemoryTokenStore())
// manager.MapClientStorage(store.NewTestClientStore())
// srv := server.NewDefaultServer(manager)
// http.HandleFunc("/authorize", func(w http.ResponseWriter, r *http.Request) {
// srv.HandleAuthorizeRequest(w, r)
// })
// http.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
// srv.HandleTokenRequest(w, r)
// })
// http.ListenAndServe(":9096", nil)
// }
package oauth2

View File

@ -1,19 +0,0 @@
package errors
import "errors"
// New returns an error that formats as the given text.
var New = errors.New
// known errors
var (
ErrInvalidRedirectURI = errors.New("invalid redirect uri")
ErrInvalidAuthorizeCode = errors.New("invalid authorize code")
ErrInvalidAccessToken = errors.New("invalid access token")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
ErrExpiredAccessToken = errors.New("expired access token")
ErrExpiredRefreshToken = errors.New("expired refresh token")
ErrMissingCodeVerifier = errors.New("missing code verifier")
ErrMissingCodeChallenge = errors.New("missing code challenge")
ErrInvalidCodeChallenge = errors.New("invalid code challenge")
)

View File

@ -1,84 +0,0 @@
package errors
import (
"errors"
"net/http"
)
// Response error response
type Response struct {
Error error
ErrorCode int
Description string
URI string
StatusCode int
Header http.Header
}
// NewResponse create the response pointer
func NewResponse(err error, statusCode int) *Response {
return &Response{
Error: err,
StatusCode: statusCode,
}
}
// SetHeader sets the header entries associated with key to
// the single element value.
func (r *Response) SetHeader(key, value string) {
if r.Header == nil {
r.Header = make(http.Header)
}
r.Header.Set(key, value)
}
// https://tools.ietf.org/html/rfc6749#section-5.2
var (
ErrInvalidRequest = errors.New("invalid_request")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrUnsupportedResponseType = errors.New("unsupported_response_type")
ErrInvalidScope = errors.New("invalid_scope")
ErrServerError = errors.New("server_error")
ErrTemporarilyUnavailable = errors.New("temporarily_unavailable")
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrUnsupportedGrantType = errors.New("unsupported_grant_type")
ErrCodeChallengeRquired = errors.New("invalid_request")
ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request")
ErrInvalidCodeChallengeLen = errors.New("invalid_request")
)
// Descriptions error description
var Descriptions = map[error]string{
ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed",
ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method",
ErrAccessDenied: "The resource owner or authorization server denied the request",
ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method",
ErrInvalidScope: "The requested scope is invalid, unknown, or malformed",
ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request",
ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server",
ErrInvalidClient: "Client authentication failed",
ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client",
ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server",
ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing",
ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported",
ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long",
}
// StatusCodes response error HTTP status code
var StatusCodes = map[error]int{
ErrInvalidRequest: 400,
ErrUnauthorizedClient: 401,
ErrAccessDenied: 403,
ErrUnsupportedResponseType: 401,
ErrInvalidScope: 400,
ErrServerError: 500,
ErrTemporarilyUnavailable: 503,
ErrInvalidClient: 401,
ErrInvalidGrant: 401,
ErrUnsupportedGrantType: 401,
ErrCodeChallengeRquired: 400,
ErrUnsupportedCodeChallengeMethod: 400,
ErrInvalidCodeChallengeLen: 400,
}

View File

@ -1,28 +0,0 @@
package oauth2
import (
"context"
"net/http"
"time"
)
type (
// GenerateBasic provide the basis of the generated token data
GenerateBasic struct {
Client ClientInfo
UserID string
CreateAt time.Time
TokenInfo TokenInfo
Request *http.Request
}
// AuthorizeGenerate generate the authorization code interface
AuthorizeGenerate interface {
Token(ctx context.Context, data *GenerateBasic) (code string, err error)
}
// AccessGenerate generate the access and refresh tokens interface
AccessGenerate interface {
Token(ctx context.Context, data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
}
)

View File

@ -1,38 +0,0 @@
package generates
import (
"bytes"
"context"
"encoding/base64"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/superseriousbusiness/oauth2/v4"
)
// NewAccessGenerate create to generate the access token instance
func NewAccessGenerate() *AccessGenerate {
return &AccessGenerate{}
}
// AccessGenerate generate the access token
type AccessGenerate struct {
}
// Token based on the UUID generated token
func (ag *AccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
buf := bytes.NewBufferString(data.Client.GetID())
buf.WriteString(data.UserID)
buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10))
access := base64.URLEncoding.EncodeToString([]byte(uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
access = strings.ToUpper(strings.TrimRight(access, "="))
refresh := ""
if isGenRefresh {
refresh = base64.URLEncoding.EncodeToString([]byte(uuid.NewSHA1(uuid.Must(uuid.NewRandom()), buf.Bytes()).String()))
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
}
return access, refresh, nil
}

View File

@ -1,30 +0,0 @@
package generates
import (
"bytes"
"context"
"encoding/base64"
"strings"
"github.com/google/uuid"
"github.com/superseriousbusiness/oauth2/v4"
)
// NewAuthorizeGenerate create to generate the authorize code instance
func NewAuthorizeGenerate() *AuthorizeGenerate {
return &AuthorizeGenerate{}
}
// AuthorizeGenerate generate the authorize code
type AuthorizeGenerate struct{}
// Token based on the UUID generated token
func (ag *AuthorizeGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic) (string, error) {
buf := bytes.NewBufferString(data.Client.GetID())
buf.WriteString(data.UserID)
token := uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes())
code := base64.URLEncoding.EncodeToString([]byte(token.String()))
code = strings.ToUpper(strings.TrimRight(code, "="))
return code, nil
}

View File

@ -1,104 +0,0 @@
package generates
import (
"context"
"encoding/base64"
"strings"
"time"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/errors"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
)
// JWTAccessClaims jwt claims
type JWTAccessClaims struct {
jwt.StandardClaims
}
// Valid claims verification
func (a *JWTAccessClaims) Valid() error {
if time.Unix(a.ExpiresAt, 0).Before(time.Now()) {
return errors.ErrInvalidAccessToken
}
return nil
}
// NewJWTAccessGenerate create to generate the jwt access token instance
func NewJWTAccessGenerate(kid string, key []byte, method jwt.SigningMethod) *JWTAccessGenerate {
return &JWTAccessGenerate{
SignedKeyID: kid,
SignedKey: key,
SignedMethod: method,
}
}
// JWTAccessGenerate generate the jwt access token
type JWTAccessGenerate struct {
SignedKeyID string
SignedKey []byte
SignedMethod jwt.SigningMethod
}
// Token based on the UUID generated token
func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
claims := &JWTAccessClaims{
StandardClaims: jwt.StandardClaims{
Audience: data.Client.GetID(),
Subject: data.UserID,
ExpiresAt: data.TokenInfo.GetAccessCreateAt().Add(data.TokenInfo.GetAccessExpiresIn()).Unix(),
},
}
token := jwt.NewWithClaims(a.SignedMethod, claims)
if a.SignedKeyID != "" {
token.Header["kid"] = a.SignedKeyID
}
var key interface{}
if a.isEs() {
v, err := jwt.ParseECPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isRsOrPS() {
v, err := jwt.ParseRSAPrivateKeyFromPEM(a.SignedKey)
if err != nil {
return "", "", err
}
key = v
} else if a.isHs() {
key = a.SignedKey
} else {
return "", "", errors.New("unsupported sign method")
}
access, err := token.SignedString(key)
if err != nil {
return "", "", err
}
refresh := ""
if isGenRefresh {
t := uuid.NewSHA1(uuid.Must(uuid.NewRandom()), []byte(access)).String()
refresh = base64.URLEncoding.EncodeToString([]byte(t))
refresh = strings.ToUpper(strings.TrimRight(refresh, "="))
}
return access, refresh, nil
}
func (a *JWTAccessGenerate) isEs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
}
func (a *JWTAccessGenerate) isRsOrPS() bool {
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
return isRs || isPs
}
func (a *JWTAccessGenerate) isHs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
}

View File

@ -1,12 +0,0 @@
#!/usr/bin/env bash
set -e
echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do
go test -race -coverprofile=profile.out -covermode=atomic "$d"
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out
fi
done

View File

@ -1,50 +0,0 @@
package oauth2
import (
"context"
"net/http"
"time"
)
// TokenGenerateRequest provide to generate the token request parameters
type TokenGenerateRequest struct {
ClientID string
ClientSecret string
UserID string
RedirectURI string
Scope string
Code string
CodeChallenge string
CodeChallengeMethod CodeChallengeMethod
Refresh string
CodeVerifier string
AccessTokenExp time.Duration
Request *http.Request
}
// Manager authorization management interface
type Manager interface {
// get the client information
GetClient(ctx context.Context, clientID string) (cli ClientInfo, err error)
// generate the authorization token(code)
GenerateAuthToken(ctx context.Context, rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error)
// generate the access token
GenerateAccessToken(ctx context.Context, rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
// refreshing an access token
RefreshAccessToken(ctx context.Context, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
// use the access token to delete the token information
RemoveAccessToken(ctx context.Context, access string) (err error)
// use the refresh token to delete the token information
RemoveRefreshToken(ctx context.Context, refresh string) (err error)
// according to the access token for corresponding token information
LoadAccessToken(ctx context.Context, access string) (ti TokenInfo, err error)
// according to the refresh token for corresponding token information
LoadRefreshToken(ctx context.Context, refresh string) (ti TokenInfo, err error)
}

View File

@ -1,39 +0,0 @@
package manage
import "time"
// Config authorization configuration parameters
type Config struct {
// access token expiration time, 0 means it doesn't expire
AccessTokenExp time.Duration
// refresh token expiration time, 0 means it doesn't expire
RefreshTokenExp time.Duration
// whether to generate the refreshing token
IsGenerateRefresh bool
}
// RefreshingConfig refreshing token config
type RefreshingConfig struct {
// access token expiration time, 0 means it doesn't expire
AccessTokenExp time.Duration
// refresh token expiration time, 0 means it doesn't expire
RefreshTokenExp time.Duration
// whether to generate the refreshing token
IsGenerateRefresh bool
// whether to reset the refreshing create time
IsResetRefreshTime bool
// whether to remove access token
IsRemoveAccess bool
// whether to remove refreshing token
IsRemoveRefreshing bool
}
// default configs
var (
DefaultCodeExp = time.Minute * 10
DefaultAuthorizeCodeTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 3, IsGenerateRefresh: true}
DefaultImplicitTokenCfg = &Config{AccessTokenExp: time.Hour * 1}
DefaultPasswordTokenCfg = &Config{AccessTokenExp: time.Hour * 2, RefreshTokenExp: time.Hour * 24 * 7, IsGenerateRefresh: true}
DefaultClientTokenCfg = &Config{AccessTokenExp: time.Hour * 2}
DefaultRefreshTokenCfg = &RefreshingConfig{IsGenerateRefresh: true, IsRemoveAccess: true, IsRemoveRefreshing: true}
)

View File

@ -1,504 +0,0 @@
package manage
import (
"context"
"time"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/errors"
"github.com/superseriousbusiness/oauth2/v4/generates"
"github.com/superseriousbusiness/oauth2/v4/models"
)
// NewDefaultManager create to default authorization management instance
func NewDefaultManager() *Manager {
m := NewManager()
// default implementation
m.MapAuthorizeGenerate(generates.NewAuthorizeGenerate())
m.MapAccessGenerate(generates.NewAccessGenerate())
return m
}
// NewManager create to authorization management instance
func NewManager() *Manager {
return &Manager{
gtcfg: make(map[oauth2.GrantType]*Config),
validateURI: DefaultValidateURI,
}
}
// Manager provide authorization management
type Manager struct {
codeExp time.Duration
gtcfg map[oauth2.GrantType]*Config
rcfg *RefreshingConfig
validateURI ValidateURIHandler
authorizeGenerate oauth2.AuthorizeGenerate
accessGenerate oauth2.AccessGenerate
tokenStore oauth2.TokenStore
clientStore oauth2.ClientStore
}
// get grant type config
func (m *Manager) grantConfig(gt oauth2.GrantType) *Config {
if c, ok := m.gtcfg[gt]; ok && c != nil {
return c
}
switch gt {
case oauth2.AuthorizationCode:
return DefaultAuthorizeCodeTokenCfg
case oauth2.Implicit:
return DefaultImplicitTokenCfg
case oauth2.PasswordCredentials:
return DefaultPasswordTokenCfg
case oauth2.ClientCredentials:
return DefaultClientTokenCfg
}
return &Config{}
}
// SetAuthorizeCodeExp set the authorization code expiration time
func (m *Manager) SetAuthorizeCodeExp(exp time.Duration) {
m.codeExp = exp
}
// SetAuthorizeCodeTokenCfg set the authorization code grant token config
func (m *Manager) SetAuthorizeCodeTokenCfg(cfg *Config) {
m.gtcfg[oauth2.AuthorizationCode] = cfg
}
// SetImplicitTokenCfg set the implicit grant token config
func (m *Manager) SetImplicitTokenCfg(cfg *Config) {
m.gtcfg[oauth2.Implicit] = cfg
}
// SetPasswordTokenCfg set the password grant token config
func (m *Manager) SetPasswordTokenCfg(cfg *Config) {
m.gtcfg[oauth2.PasswordCredentials] = cfg
}
// SetClientTokenCfg set the client grant token config
func (m *Manager) SetClientTokenCfg(cfg *Config) {
m.gtcfg[oauth2.ClientCredentials] = cfg
}
// SetRefreshTokenCfg set the refreshing token config
func (m *Manager) SetRefreshTokenCfg(cfg *RefreshingConfig) {
m.rcfg = cfg
}
// SetValidateURIHandler set the validates that RedirectURI is contained in baseURI
func (m *Manager) SetValidateURIHandler(handler ValidateURIHandler) {
m.validateURI = handler
}
// MapAuthorizeGenerate mapping the authorize code generate interface
func (m *Manager) MapAuthorizeGenerate(gen oauth2.AuthorizeGenerate) {
m.authorizeGenerate = gen
}
// MapAccessGenerate mapping the access token generate interface
func (m *Manager) MapAccessGenerate(gen oauth2.AccessGenerate) {
m.accessGenerate = gen
}
// MapClientStorage mapping the client store interface
func (m *Manager) MapClientStorage(stor oauth2.ClientStore) {
m.clientStore = stor
}
// MustClientStorage mandatory mapping the client store interface
func (m *Manager) MustClientStorage(stor oauth2.ClientStore, err error) {
if err != nil {
panic(err.Error())
}
m.clientStore = stor
}
// MapTokenStorage mapping the token store interface
func (m *Manager) MapTokenStorage(stor oauth2.TokenStore) {
m.tokenStore = stor
}
// MustTokenStorage mandatory mapping the token store interface
func (m *Manager) MustTokenStorage(stor oauth2.TokenStore, err error) {
if err != nil {
panic(err)
}
m.tokenStore = stor
}
// GetClient get the client information
func (m *Manager) GetClient(ctx context.Context, clientID string) (cli oauth2.ClientInfo, err error) {
cli, err = m.clientStore.GetByID(ctx, clientID)
if err != nil {
return
} else if cli == nil {
err = errors.ErrInvalidClient
}
return
}
// GenerateAuthToken generate the authorization token(code)
func (m *Manager) GenerateAuthToken(ctx context.Context, rt oauth2.ResponseType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
if err != nil {
return nil, err
} else if tgr.RedirectURI != "" {
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
return nil, err
}
}
ti := models.NewToken()
ti.SetClientID(tgr.ClientID)
ti.SetUserID(tgr.UserID)
ti.SetRedirectURI(tgr.RedirectURI)
ti.SetScope(tgr.Scope)
createAt := time.Now()
td := &oauth2.GenerateBasic{
Client: cli,
UserID: tgr.UserID,
CreateAt: createAt,
TokenInfo: ti,
Request: tgr.Request,
}
switch rt {
case oauth2.Code:
codeExp := m.codeExp
if codeExp == 0 {
codeExp = DefaultCodeExp
}
ti.SetCodeCreateAt(createAt)
ti.SetCodeExpiresIn(codeExp)
if exp := tgr.AccessTokenExp; exp > 0 {
ti.SetAccessExpiresIn(exp)
}
if tgr.CodeChallenge != "" {
ti.SetCodeChallenge(tgr.CodeChallenge)
ti.SetCodeChallengeMethod(tgr.CodeChallengeMethod)
}
tv, err := m.authorizeGenerate.Token(ctx, td)
if err != nil {
return nil, err
}
ti.SetCode(tv)
case oauth2.Token:
// set access token expires
icfg := m.grantConfig(oauth2.Implicit)
aexp := icfg.AccessTokenExp
if exp := tgr.AccessTokenExp; exp > 0 {
aexp = exp
}
ti.SetAccessCreateAt(createAt)
ti.SetAccessExpiresIn(aexp)
if icfg.IsGenerateRefresh {
ti.SetRefreshCreateAt(createAt)
ti.SetRefreshExpiresIn(icfg.RefreshTokenExp)
}
tv, rv, err := m.accessGenerate.Token(ctx, td, icfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(tv)
if rv != "" {
ti.SetRefresh(rv)
}
}
err = m.tokenStore.Create(ctx, ti)
if err != nil {
return nil, err
}
return ti, nil
}
// get authorization code data
func (m *Manager) getAuthorizationCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
ti, err := m.tokenStore.GetByCode(ctx, code)
if err != nil {
return nil, err
} else if ti == nil || ti.GetCode() != code || ti.GetCodeCreateAt().Add(ti.GetCodeExpiresIn()).Before(time.Now()) {
err = errors.ErrInvalidAuthorizeCode
return nil, errors.ErrInvalidAuthorizeCode
}
return ti, nil
}
// delete authorization code data
func (m *Manager) delAuthorizationCode(ctx context.Context, code string) error {
return m.tokenStore.RemoveByCode(ctx, code)
}
// get and delete authorization code data
func (m *Manager) getAndDelAuthorizationCode(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
code := tgr.Code
ti, err := m.getAuthorizationCode(ctx, code)
if err != nil {
return nil, err
} else if ti.GetClientID() != tgr.ClientID {
return nil, errors.ErrInvalidAuthorizeCode
} else if codeURI := ti.GetRedirectURI(); codeURI != "" && codeURI != tgr.RedirectURI {
return nil, errors.ErrInvalidAuthorizeCode
}
err = m.delAuthorizationCode(ctx, code)
if err != nil {
return nil, err
}
return ti, nil
}
func (m *Manager) validateCodeChallenge(ti oauth2.TokenInfo, ver string) error {
cc := ti.GetCodeChallenge()
// early return
if cc == "" && ver == "" {
return nil
}
if cc == "" {
return errors.ErrMissingCodeVerifier
}
if ver == "" {
return errors.ErrMissingCodeVerifier
}
ccm := ti.GetCodeChallengeMethod()
if ccm.String() == "" {
ccm = oauth2.CodeChallengePlain
}
if !ccm.Validate(cc, ver) {
return errors.ErrInvalidCodeChallenge
}
return nil
}
// GenerateAccessToken generate the access token
func (m *Manager) GenerateAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
if err != nil {
return nil, err
}
if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
if !cliPass.VerifyPassword(tgr.ClientSecret) {
return nil, errors.ErrInvalidClient
}
} else if len(cli.GetSecret()) > 0 && tgr.ClientSecret != cli.GetSecret() {
return nil, errors.ErrInvalidClient
}
if tgr.RedirectURI != "" {
if err := m.validateURI(cli.GetDomain(), tgr.RedirectURI); err != nil {
return nil, err
}
}
if gt == oauth2.AuthorizationCode {
ti, err := m.getAndDelAuthorizationCode(ctx, tgr)
if err != nil {
return nil, err
}
if err := m.validateCodeChallenge(ti, tgr.CodeVerifier); err != nil {
return nil, err
}
tgr.UserID = ti.GetUserID()
tgr.Scope = ti.GetScope()
if exp := ti.GetAccessExpiresIn(); exp > 0 {
tgr.AccessTokenExp = exp
}
}
ti := models.NewToken()
ti.SetClientID(tgr.ClientID)
ti.SetUserID(tgr.UserID)
ti.SetRedirectURI(tgr.RedirectURI)
ti.SetScope(tgr.Scope)
createAt := time.Now()
ti.SetAccessCreateAt(createAt)
// set access token expires
gcfg := m.grantConfig(gt)
aexp := gcfg.AccessTokenExp
if exp := tgr.AccessTokenExp; exp > 0 {
aexp = exp
}
ti.SetAccessExpiresIn(aexp)
if gcfg.IsGenerateRefresh {
ti.SetRefreshCreateAt(createAt)
ti.SetRefreshExpiresIn(gcfg.RefreshTokenExp)
}
td := &oauth2.GenerateBasic{
Client: cli,
UserID: tgr.UserID,
CreateAt: createAt,
TokenInfo: ti,
Request: tgr.Request,
}
av, rv, err := m.accessGenerate.Token(ctx, td, gcfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(av)
if rv != "" {
ti.SetRefresh(rv)
}
err = m.tokenStore.Create(ctx, ti)
if err != nil {
return nil, err
}
return ti, nil
}
// RefreshAccessToken refreshing an access token
func (m *Manager) RefreshAccessToken(ctx context.Context, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
cli, err := m.GetClient(ctx, tgr.ClientID)
if err != nil {
return nil, err
} else if cliPass, ok := cli.(oauth2.ClientPasswordVerifier); ok {
if !cliPass.VerifyPassword(tgr.ClientSecret) {
return nil, errors.ErrInvalidClient
}
} else if tgr.ClientSecret != cli.GetSecret() {
return nil, errors.ErrInvalidClient
}
ti, err := m.LoadRefreshToken(ctx, tgr.Refresh)
if err != nil {
return nil, err
} else if ti.GetClientID() != tgr.ClientID {
return nil, errors.ErrInvalidRefreshToken
}
oldAccess, oldRefresh := ti.GetAccess(), ti.GetRefresh()
td := &oauth2.GenerateBasic{
Client: cli,
UserID: ti.GetUserID(),
CreateAt: time.Now(),
TokenInfo: ti,
Request: tgr.Request,
}
rcfg := DefaultRefreshTokenCfg
if v := m.rcfg; v != nil {
rcfg = v
}
ti.SetAccessCreateAt(td.CreateAt)
if v := rcfg.AccessTokenExp; v > 0 {
ti.SetAccessExpiresIn(v)
}
if v := rcfg.RefreshTokenExp; v > 0 {
ti.SetRefreshExpiresIn(v)
}
if rcfg.IsResetRefreshTime {
ti.SetRefreshCreateAt(td.CreateAt)
}
if scope := tgr.Scope; scope != "" {
ti.SetScope(scope)
}
tv, rv, err := m.accessGenerate.Token(ctx, td, rcfg.IsGenerateRefresh)
if err != nil {
return nil, err
}
ti.SetAccess(tv)
if rv != "" {
ti.SetRefresh(rv)
}
if err := m.tokenStore.Create(ctx, ti); err != nil {
return nil, err
}
if rcfg.IsRemoveAccess {
// remove the old access token
if err := m.tokenStore.RemoveByAccess(ctx, oldAccess); err != nil {
return nil, err
}
}
if rcfg.IsRemoveRefreshing && rv != "" {
// remove the old refresh token
if err := m.tokenStore.RemoveByRefresh(ctx, oldRefresh); err != nil {
return nil, err
}
}
if rv == "" {
ti.SetRefresh("")
ti.SetRefreshCreateAt(time.Now())
ti.SetRefreshExpiresIn(0)
}
return ti, nil
}
// RemoveAccessToken use the access token to delete the token information
func (m *Manager) RemoveAccessToken(ctx context.Context, access string) error {
if access == "" {
return errors.ErrInvalidAccessToken
}
return m.tokenStore.RemoveByAccess(ctx, access)
}
// RemoveRefreshToken use the refresh token to delete the token information
func (m *Manager) RemoveRefreshToken(ctx context.Context, refresh string) error {
if refresh == "" {
return errors.ErrInvalidAccessToken
}
return m.tokenStore.RemoveByRefresh(ctx, refresh)
}
// LoadAccessToken according to the access token for corresponding token information
func (m *Manager) LoadAccessToken(ctx context.Context, access string) (oauth2.TokenInfo, error) {
if access == "" {
return nil, errors.ErrInvalidAccessToken
}
ct := time.Now()
ti, err := m.tokenStore.GetByAccess(ctx, access)
if err != nil {
return nil, err
} else if ti == nil || ti.GetAccess() != access {
return nil, errors.ErrInvalidAccessToken
} else if ti.GetRefresh() != "" && ti.GetRefreshExpiresIn() != 0 &&
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
return nil, errors.ErrExpiredRefreshToken
} else if ti.GetAccessExpiresIn() != 0 &&
ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
return nil, errors.ErrExpiredAccessToken
}
return ti, nil
}
// LoadRefreshToken according to the refresh token for corresponding token information
func (m *Manager) LoadRefreshToken(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
if refresh == "" {
return nil, errors.ErrInvalidRefreshToken
}
ti, err := m.tokenStore.GetByRefresh(ctx, refresh)
if err != nil {
return nil, err
} else if ti == nil || ti.GetRefresh() != refresh {
return nil, errors.ErrInvalidRefreshToken
} else if ti.GetRefreshExpiresIn() != 0 && // refresh token set to not expire
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(time.Now()) {
return nil, errors.ErrExpiredRefreshToken
}
return ti, nil
}

View File

@ -1,30 +0,0 @@
package manage
import (
"net/url"
"strings"
"github.com/superseriousbusiness/oauth2/v4/errors"
)
type (
// ValidateURIHandler validates that redirectURI is contained in baseURI
ValidateURIHandler func(baseURI, redirectURI string) error
)
// DefaultValidateURI validates that redirectURI is contained in baseURI
func DefaultValidateURI(baseURI string, redirectURI string) error {
base, err := url.Parse(baseURI)
if err != nil {
return err
}
redirect, err := url.Parse(redirectURI)
if err != nil {
return err
}
if !strings.HasSuffix(redirect.Host, base.Host) {
return errors.ErrInvalidRedirectURI
}
return nil
}

View File

@ -1,59 +0,0 @@
package oauth2
import (
"time"
)
type (
// ClientInfo the client information model interface
ClientInfo interface {
GetID() string
GetSecret() string
GetDomain() string
GetUserID() string
}
// ClientPasswordVerifier the password handler interface
ClientPasswordVerifier interface {
VerifyPassword(string) bool
}
// TokenInfo the token information model interface
TokenInfo interface {
New() TokenInfo
GetClientID() string
SetClientID(string)
GetUserID() string
SetUserID(string)
GetRedirectURI() string
SetRedirectURI(string)
GetScope() string
SetScope(string)
GetCode() string
SetCode(string)
GetCodeCreateAt() time.Time
SetCodeCreateAt(time.Time)
GetCodeExpiresIn() time.Duration
SetCodeExpiresIn(time.Duration)
GetCodeChallenge() string
SetCodeChallenge(string)
GetCodeChallengeMethod() CodeChallengeMethod
SetCodeChallengeMethod(CodeChallengeMethod)
GetAccess() string
SetAccess(string)
GetAccessCreateAt() time.Time
SetAccessCreateAt(time.Time)
GetAccessExpiresIn() time.Duration
SetAccessExpiresIn(time.Duration)
GetRefresh() string
SetRefresh(string)
GetRefreshCreateAt() time.Time
SetRefreshCreateAt(time.Time)
GetRefreshExpiresIn() time.Duration
SetRefreshExpiresIn(time.Duration)
}
)

View File

@ -1,46 +0,0 @@
package models
// Client client model
type Client interface {
GetID() string
GetSecret() string
GetDomain() string
GetUserID() string
}
func New(id string, secret string, domain string, userID string) Client {
return &simpleClient{
id: id,
secret: secret,
domain: domain,
userID: userID,
}
}
// simpleClient is a very simple client model that satisfies the Client interface
type simpleClient struct {
id string
secret string
domain string
userID string
}
// GetID client id
func (c *simpleClient) GetID() string {
return c.id
}
// GetSecret client secret
func (c *simpleClient) GetSecret() string {
return c.secret
}
// GetDomain client domain
func (c *simpleClient) GetDomain() string {
return c.domain
}
// GetUserID user id
func (c *simpleClient) GetUserID() string {
return c.userID
}

View File

@ -1,186 +0,0 @@
package models
import (
"time"
"github.com/superseriousbusiness/oauth2/v4"
)
// NewToken create to token model instance
func NewToken() *Token {
return &Token{}
}
// Token token model
type Token struct {
ClientID string `bson:"ClientID"`
UserID string `bson:"UserID"`
RedirectURI string `bson:"RedirectURI"`
Scope string `bson:"Scope"`
Code string `bson:"Code"`
CodeChallenge string `bson:"CodeChallenge"`
CodeChallengeMethod string `bson:"CodeChallengeMethod"`
CodeCreateAt time.Time `bson:"CodeCreateAt"`
CodeExpiresIn time.Duration `bson:"CodeExpiresIn"`
Access string `bson:"Access"`
AccessCreateAt time.Time `bson:"AccessCreateAt"`
AccessExpiresIn time.Duration `bson:"AccessExpiresIn"`
Refresh string `bson:"Refresh"`
RefreshCreateAt time.Time `bson:"RefreshCreateAt"`
RefreshExpiresIn time.Duration `bson:"RefreshExpiresIn"`
}
// New create to token model instance
func (t *Token) New() oauth2.TokenInfo {
return NewToken()
}
// GetClientID the client id
func (t *Token) GetClientID() string {
return t.ClientID
}
// SetClientID the client id
func (t *Token) SetClientID(clientID string) {
t.ClientID = clientID
}
// GetUserID the user id
func (t *Token) GetUserID() string {
return t.UserID
}
// SetUserID the user id
func (t *Token) SetUserID(userID string) {
t.UserID = userID
}
// GetRedirectURI redirect URI
func (t *Token) GetRedirectURI() string {
return t.RedirectURI
}
// SetRedirectURI redirect URI
func (t *Token) SetRedirectURI(redirectURI string) {
t.RedirectURI = redirectURI
}
// GetScope get scope of authorization
func (t *Token) GetScope() string {
return t.Scope
}
// SetScope get scope of authorization
func (t *Token) SetScope(scope string) {
t.Scope = scope
}
// GetCode authorization code
func (t *Token) GetCode() string {
return t.Code
}
// SetCode authorization code
func (t *Token) SetCode(code string) {
t.Code = code
}
// GetCodeCreateAt create Time
func (t *Token) GetCodeCreateAt() time.Time {
return t.CodeCreateAt
}
// SetCodeCreateAt create Time
func (t *Token) SetCodeCreateAt(createAt time.Time) {
t.CodeCreateAt = createAt
}
// GetCodeExpiresIn the lifetime in seconds of the authorization code
func (t *Token) GetCodeExpiresIn() time.Duration {
return t.CodeExpiresIn
}
// SetCodeExpiresIn the lifetime in seconds of the authorization code
func (t *Token) SetCodeExpiresIn(exp time.Duration) {
t.CodeExpiresIn = exp
}
// GetCodeChallenge challenge code
func (t *Token) GetCodeChallenge() string {
return t.CodeChallenge
}
// SetCodeChallenge challenge code
func (t *Token) SetCodeChallenge(code string) {
t.CodeChallenge = code
}
// GetCodeChallengeMethod challenge method
func (t *Token) GetCodeChallengeMethod() oauth2.CodeChallengeMethod {
return oauth2.CodeChallengeMethod(t.CodeChallengeMethod)
}
// SetCodeChallengeMethod challenge method
func (t *Token) SetCodeChallengeMethod(method oauth2.CodeChallengeMethod) {
t.CodeChallengeMethod = string(method)
}
// GetAccess access Token
func (t *Token) GetAccess() string {
return t.Access
}
// SetAccess access Token
func (t *Token) SetAccess(access string) {
t.Access = access
}
// GetAccessCreateAt create Time
func (t *Token) GetAccessCreateAt() time.Time {
return t.AccessCreateAt
}
// SetAccessCreateAt create Time
func (t *Token) SetAccessCreateAt(createAt time.Time) {
t.AccessCreateAt = createAt
}
// GetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) GetAccessExpiresIn() time.Duration {
return t.AccessExpiresIn
}
// SetAccessExpiresIn the lifetime in seconds of the access token
func (t *Token) SetAccessExpiresIn(exp time.Duration) {
t.AccessExpiresIn = exp
}
// GetRefresh refresh Token
func (t *Token) GetRefresh() string {
return t.Refresh
}
// SetRefresh refresh Token
func (t *Token) SetRefresh(refresh string) {
t.Refresh = refresh
}
// GetRefreshCreateAt create Time
func (t *Token) GetRefreshCreateAt() time.Time {
return t.RefreshCreateAt
}
// SetRefreshCreateAt create Time
func (t *Token) SetRefreshCreateAt(createAt time.Time) {
t.RefreshCreateAt = createAt
}
// GetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) GetRefreshExpiresIn() time.Duration {
return t.RefreshExpiresIn
}
// SetRefreshExpiresIn the lifetime in seconds of the refresh token
func (t *Token) SetRefreshExpiresIn(exp time.Duration) {
t.RefreshExpiresIn = exp
}

View File

@ -1,50 +0,0 @@
package server
import (
"net/http"
"time"
"github.com/superseriousbusiness/oauth2/v4"
)
// Config configuration parameters
type Config struct {
TokenType string // token type
AllowGetAccessRequest bool // to allow GET requests for the token
AllowedResponseTypes []oauth2.ResponseType // allow the authorization type
AllowedGrantTypes []oauth2.GrantType // allow the grant type
AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod
ForcePKCE bool
}
// NewConfig create to configuration instance
func NewConfig() *Config {
return &Config{
TokenType: "Bearer",
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code, oauth2.Token},
AllowedGrantTypes: []oauth2.GrantType{
oauth2.AuthorizationCode,
oauth2.PasswordCredentials,
oauth2.ClientCredentials,
oauth2.Refreshing,
},
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
oauth2.CodeChallengePlain,
oauth2.CodeChallengeS256,
},
}
}
// AuthorizeRequest authorization request
type AuthorizeRequest struct {
ResponseType oauth2.ResponseType
ClientID string
Scope string
RedirectURI string
State string
UserID string
CodeChallenge string
CodeChallengeMethod oauth2.CodeChallengeMethod
AccessTokenExp time.Duration
Request *http.Request
}

View File

@ -1,66 +0,0 @@
package server
import (
"net/http"
"time"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/errors"
)
type (
// ClientInfoHandler get client info from request
ClientInfoHandler func(r *http.Request) (clientID, clientSecret string, err error)
// ClientAuthorizedHandler check the client allows to use this authorization grant type
ClientAuthorizedHandler func(clientID string, grant oauth2.GrantType) (allowed bool, err error)
// ClientScopeHandler check the client allows to use scope
ClientScopeHandler func(tgr *oauth2.TokenGenerateRequest) (allowed bool, err error)
// UserAuthorizationHandler get user id from request authorization
UserAuthorizationHandler func(w http.ResponseWriter, r *http.Request) (userID string, err error)
// PasswordAuthorizationHandler get user id from username and password
PasswordAuthorizationHandler func(username, password string) (userID string, err error)
// RefreshingScopeHandler check the scope of the refreshing token
RefreshingScopeHandler func(tgr *oauth2.TokenGenerateRequest, oldScope string) (allowed bool, err error)
// RefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other
RefreshingValidationHandler func(ti oauth2.TokenInfo) (allowed bool, err error)
// ResponseErrorHandler response error handing
ResponseErrorHandler func(re *errors.Response)
// InternalErrorHandler internal error handing
InternalErrorHandler func(err error) (re *errors.Response)
// AuthorizeScopeHandler set the authorized scope
AuthorizeScopeHandler func(w http.ResponseWriter, r *http.Request) (scope string, err error)
// AccessTokenExpHandler set expiration date for the access token
AccessTokenExpHandler func(w http.ResponseWriter, r *http.Request) (exp time.Duration, err error)
// ExtensionFieldsHandler in response to the access token with the extension of the field
ExtensionFieldsHandler func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{})
)
// ClientFormHandler get client data from form
func ClientFormHandler(r *http.Request) (string, string, error) {
clientID := r.Form.Get("client_id")
if clientID == "" {
return "", "", errors.ErrInvalidClient
}
clientSecret := r.Form.Get("client_secret")
return clientID, clientSecret, nil
}
// ClientBasicHandler get client data from basic authorization
func ClientBasicHandler(r *http.Request) (string, string, error) {
username, password, ok := r.BasicAuth()
if !ok {
return "", "", errors.ErrInvalidClient
}
return username, password, nil
}

View File

@ -1,589 +0,0 @@
package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/superseriousbusiness/oauth2/v4"
"github.com/superseriousbusiness/oauth2/v4/errors"
)
// NewDefaultServer create a default authorization server
func NewDefaultServer(manager oauth2.Manager) *Server {
return NewServer(NewConfig(), manager)
}
// NewServer create authorization server
func NewServer(cfg *Config, manager oauth2.Manager) *Server {
srv := &Server{
Config: cfg,
Manager: manager,
}
// default handler
srv.ClientInfoHandler = ClientBasicHandler
srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
return "", errors.ErrAccessDenied
}
srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
return "", errors.ErrAccessDenied
}
return srv
}
// Server Provide authorization server
type Server struct {
Config *Config
Manager oauth2.Manager
ClientInfoHandler ClientInfoHandler
ClientAuthorizedHandler ClientAuthorizedHandler
ClientScopeHandler ClientScopeHandler
UserAuthorizationHandler UserAuthorizationHandler
PasswordAuthorizationHandler PasswordAuthorizationHandler
RefreshingValidationHandler RefreshingValidationHandler
RefreshingScopeHandler RefreshingScopeHandler
ResponseErrorHandler ResponseErrorHandler
InternalErrorHandler InternalErrorHandler
ExtensionFieldsHandler ExtensionFieldsHandler
AccessTokenExpHandler AccessTokenExpHandler
AuthorizeScopeHandler AuthorizeScopeHandler
}
func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
if req == nil {
return err
}
data, _, _ := s.GetErrorData(err)
return s.redirect(w, req, data)
}
func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
uri, err := s.GetRedirectURI(req, data)
if err != nil {
return err
}
w.Header().Set("Location", uri)
w.WriteHeader(302)
return nil
}
func (s *Server) tokenError(w http.ResponseWriter, err error) error {
data, statusCode, header := s.GetErrorData(err)
return s.token(w, data, header, statusCode)
}
func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
w.Header().Set("Content-Type", "application/json;charset=UTF-8")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
for key := range header {
w.Header().Set(key, header.Get(key))
}
status := http.StatusOK
if len(statusCode) > 0 && statusCode[0] > 0 {
status = statusCode[0]
}
w.WriteHeader(status)
return json.NewEncoder(w).Encode(data)
}
// GetRedirectURI get redirect uri
func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
u, err := url.Parse(req.RedirectURI)
if err != nil {
return "", err
}
q := u.Query()
if req.State != "" {
q.Set("state", req.State)
}
for k, v := range data {
q.Set(k, fmt.Sprint(v))
}
switch req.ResponseType {
case oauth2.Code:
u.RawQuery = q.Encode()
case oauth2.Token:
u.RawQuery = ""
fragment, err := url.QueryUnescape(q.Encode())
if err != nil {
return "", err
}
u.Fragment = fragment
}
return u.String(), nil
}
// CheckResponseType check allows response type
func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
for _, art := range s.Config.AllowedResponseTypes {
if art == rt {
return true
}
}
return false
}
// CheckCodeChallengeMethod checks for allowed code challenge method
func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool {
for _, c := range s.Config.AllowedCodeChallengeMethods {
if c == ccm {
return true
}
}
return false
}
// ValidationAuthorizeRequest the authorization request validation
func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
if !(r.Method == "GET" || r.Method == "POST") ||
clientID == "" {
return nil, errors.ErrInvalidRequest
}
resType := oauth2.ResponseType(r.FormValue("response_type"))
if resType.String() == "" {
return nil, errors.ErrUnsupportedResponseType
} else if allowed := s.CheckResponseType(resType); !allowed {
return nil, errors.ErrUnauthorizedClient
}
cc := r.FormValue("code_challenge")
if cc == "" && s.Config.ForcePKCE {
return nil, errors.ErrCodeChallengeRquired
}
if cc != "" && (len(cc) < 43 || len(cc) > 128) {
return nil, errors.ErrInvalidCodeChallengeLen
}
ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
// set default
if ccm == "" {
ccm = oauth2.CodeChallengePlain
}
if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) {
return nil, errors.ErrUnsupportedCodeChallengeMethod
}
req := &AuthorizeRequest{
RedirectURI: redirectURI,
ResponseType: resType,
ClientID: clientID,
State: r.FormValue("state"),
Scope: r.FormValue("scope"),
Request: r,
CodeChallenge: cc,
CodeChallengeMethod: ccm,
}
return req, nil
}
// GetAuthorizeToken get authorization token(code)
func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) {
// check the client allows the grant type
if fn := s.ClientAuthorizedHandler; fn != nil {
gt := oauth2.AuthorizationCode
if req.ResponseType == oauth2.Token {
gt = oauth2.Implicit
}
allowed, err := fn(req.ClientID, gt)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrUnauthorizedClient
}
}
tgr := &oauth2.TokenGenerateRequest{
ClientID: req.ClientID,
UserID: req.UserID,
RedirectURI: req.RedirectURI,
Scope: req.Scope,
AccessTokenExp: req.AccessTokenExp,
Request: req.Request,
}
// check the client allows the authorized scope
if fn := s.ClientScopeHandler; fn != nil {
allowed, err := fn(tgr)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
tgr.CodeChallenge = req.CodeChallenge
tgr.CodeChallengeMethod = req.CodeChallengeMethod
return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr)
}
// GetAuthorizeData get authorization response data
func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
if rt == oauth2.Code {
return map[string]interface{}{
"code": ti.GetCode(),
}
}
return s.GetTokenData(ti)
}
// HandleAuthorizeRequest the authorization request handling
func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
req, err := s.ValidationAuthorizeRequest(r)
if err != nil {
return s.redirectError(w, req, err)
}
// user authorization
userID, err := s.UserAuthorizationHandler(w, r)
if err != nil {
return s.redirectError(w, req, err)
} else if userID == "" {
return nil
}
req.UserID = userID
// specify the scope of authorization
if fn := s.AuthorizeScopeHandler; fn != nil {
scope, err := fn(w, r)
if err != nil {
return err
} else if scope != "" {
req.Scope = scope
}
}
// specify the expiration time of access token
if fn := s.AccessTokenExpHandler; fn != nil {
exp, err := fn(w, r)
if err != nil {
return err
}
req.AccessTokenExp = exp
}
ti, err := s.GetAuthorizeToken(ctx, req)
if err != nil {
return s.redirectError(w, req, err)
}
// If the redirect URI is empty, the default domain provided by the client is used.
if req.RedirectURI == "" {
client, err := s.Manager.GetClient(ctx, req.ClientID)
if err != nil {
return err
}
req.RedirectURI = client.GetDomain()
}
return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
}
// ValidationTokenRequest the token request validation
func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
if v := r.Method; !(v == "POST" ||
(s.Config.AllowGetAccessRequest && v == "GET")) {
return "", nil, errors.ErrInvalidRequest
}
gt := oauth2.GrantType(r.FormValue("grant_type"))
if gt.String() == "" {
return "", nil, errors.ErrUnsupportedGrantType
}
if !s.CheckGrantType(gt) {
return "", nil, errors.ErrUnsupportedGrantType
}
clientID, clientSecret, err := s.ClientInfoHandler(r)
if err != nil {
return "", nil, err
}
tgr := &oauth2.TokenGenerateRequest{
ClientID: clientID,
ClientSecret: clientSecret,
Request: r,
}
switch gt {
case oauth2.AuthorizationCode:
tgr.RedirectURI = r.FormValue("redirect_uri")
tgr.Code = r.FormValue("code")
if tgr.RedirectURI == "" ||
tgr.Code == "" {
return "", nil, errors.ErrInvalidRequest
}
tgr.CodeVerifier = r.FormValue("code_verifier")
if s.Config.ForcePKCE && tgr.CodeVerifier == "" {
return "", nil, errors.ErrInvalidRequest
}
case oauth2.PasswordCredentials:
tgr.Scope = r.FormValue("scope")
username, password := r.FormValue("username"), r.FormValue("password")
if username == "" || password == "" {
return "", nil, errors.ErrInvalidRequest
}
userID, err := s.PasswordAuthorizationHandler(username, password)
if err != nil {
return "", nil, err
} else if userID == "" {
return "", nil, errors.ErrInvalidGrant
}
tgr.UserID = userID
case oauth2.ClientCredentials:
tgr.Scope = r.FormValue("scope")
tgr.RedirectURI = r.FormValue("redirect_uri")
case oauth2.Refreshing:
tgr.Refresh = r.FormValue("refresh_token")
tgr.Scope = r.FormValue("scope")
if tgr.Refresh == "" {
return "", nil, errors.ErrInvalidRequest
}
}
return gt, tgr, nil
}
// CheckGrantType check allows grant type
func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
for _, agt := range s.Config.AllowedGrantTypes {
if agt == gt {
return true
}
}
return false
}
// GetAccessToken access token
func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo,
error) {
if allowed := s.CheckGrantType(gt); !allowed {
return nil, errors.ErrUnauthorizedClient
}
if fn := s.ClientAuthorizedHandler; fn != nil {
allowed, err := fn(tgr.ClientID, gt)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrUnauthorizedClient
}
}
switch gt {
case oauth2.AuthorizationCode:
ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr)
if err != nil {
switch err {
case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge:
return nil, errors.ErrInvalidGrant
case errors.ErrInvalidClient:
return nil, errors.ErrInvalidClient
default:
return nil, err
}
}
return ti, nil
case oauth2.PasswordCredentials, oauth2.ClientCredentials:
if fn := s.ClientScopeHandler; fn != nil {
allowed, err := fn(tgr)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
return s.Manager.GenerateAccessToken(ctx, gt, tgr)
case oauth2.Refreshing:
// check scope
if scopeFn := s.RefreshingScopeHandler; tgr.Scope != "" && scopeFn != nil {
rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
if err != nil {
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
return nil, errors.ErrInvalidGrant
}
return nil, err
}
allowed, err := scopeFn(tgr, rti.GetScope())
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
if validationFn := s.RefreshingValidationHandler; validationFn != nil {
rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
if err != nil {
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
return nil, errors.ErrInvalidGrant
}
return nil, err
}
allowed, err := validationFn(rti)
if err != nil {
return nil, err
} else if !allowed {
return nil, errors.ErrInvalidScope
}
}
ti, err := s.Manager.RefreshAccessToken(ctx, tgr)
if err != nil {
if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
return nil, errors.ErrInvalidGrant
}
return nil, err
}
return ti, nil
}
return nil, errors.ErrUnsupportedGrantType
}
// GetTokenData token data
func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
data := map[string]interface{}{
"access_token": ti.GetAccess(),
"token_type": s.Config.TokenType,
"expires_in": int64(ti.GetAccessExpiresIn() / time.Second),
}
if scope := ti.GetScope(); scope != "" {
data["scope"] = scope
}
if refresh := ti.GetRefresh(); refresh != "" {
data["refresh_token"] = refresh
}
if fn := s.ExtensionFieldsHandler; fn != nil {
ext := fn(ti)
for k, v := range ext {
if _, ok := data[k]; ok {
continue
}
data[k] = v
}
}
return data
}
// HandleTokenRequest token request handling
func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
gt, tgr, err := s.ValidationTokenRequest(r)
if err != nil {
return s.tokenError(w, err)
}
ti, err := s.GetAccessToken(ctx, gt, tgr)
if err != nil {
return s.tokenError(w, err)
}
return s.token(w, s.GetTokenData(ti), nil)
}
// GetErrorData get error response data
func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
var re errors.Response
if v, ok := errors.Descriptions[err]; ok {
re.Error = err
re.Description = v
re.StatusCode = errors.StatusCodes[err]
} else {
if fn := s.InternalErrorHandler; fn != nil {
if v := fn(err); v != nil {
re = *v
}
}
if re.Error == nil {
re.Error = errors.ErrServerError
re.Description = errors.Descriptions[errors.ErrServerError]
re.StatusCode = errors.StatusCodes[errors.ErrServerError]
}
}
if fn := s.ResponseErrorHandler; fn != nil {
fn(&re)
}
data := make(map[string]interface{})
if err := re.Error; err != nil {
data["error"] = err.Error()
}
if v := re.ErrorCode; v != 0 {
data["error_code"] = v
}
if v := re.Description; v != "" {
data["error_description"] = v
}
if v := re.URI; v != "" {
data["error_uri"] = v
}
statusCode := http.StatusInternalServerError
if v := re.StatusCode; v > 0 {
statusCode = v
}
return data, statusCode, re.Header
}
// BearerAuth parse bearer token
func (s *Server) BearerAuth(r *http.Request) (string, bool) {
auth := r.Header.Get("Authorization")
prefix := "Bearer "
token := ""
if auth != "" && strings.HasPrefix(auth, prefix) {
token = auth[len(prefix):]
} else {
token = r.FormValue("access_token")
}
return token, token != ""
}
// ValidationBearerToken validation the bearer tokens
// https://tools.ietf.org/html/rfc6750
func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
ctx := r.Context()
accessToken, ok := s.BearerAuth(r)
if !ok {
return nil, errors.ErrInvalidAccessToken
}
return s.Manager.LoadAccessToken(ctx, accessToken)
}

View File

@ -1,85 +0,0 @@
package server
import (
"github.com/superseriousbusiness/oauth2/v4"
)
// SetTokenType token type
func (s *Server) SetTokenType(tokenType string) {
s.Config.TokenType = tokenType
}
// SetAllowGetAccessRequest to allow GET requests for the token
func (s *Server) SetAllowGetAccessRequest(allow bool) {
s.Config.AllowGetAccessRequest = allow
}
// SetAllowedResponseType allow the authorization types
func (s *Server) SetAllowedResponseType(types ...oauth2.ResponseType) {
s.Config.AllowedResponseTypes = types
}
// SetAllowedGrantType allow the grant types
func (s *Server) SetAllowedGrantType(types ...oauth2.GrantType) {
s.Config.AllowedGrantTypes = types
}
// SetClientInfoHandler get client info from request
func (s *Server) SetClientInfoHandler(handler ClientInfoHandler) {
s.ClientInfoHandler = handler
}
// SetClientAuthorizedHandler check the client allows to use this authorization grant type
func (s *Server) SetClientAuthorizedHandler(handler ClientAuthorizedHandler) {
s.ClientAuthorizedHandler = handler
}
// SetClientScopeHandler check the client allows to use scope
func (s *Server) SetClientScopeHandler(handler ClientScopeHandler) {
s.ClientScopeHandler = handler
}
// SetUserAuthorizationHandler get user id from request authorization
func (s *Server) SetUserAuthorizationHandler(handler UserAuthorizationHandler) {
s.UserAuthorizationHandler = handler
}
// SetPasswordAuthorizationHandler get user id from username and password
func (s *Server) SetPasswordAuthorizationHandler(handler PasswordAuthorizationHandler) {
s.PasswordAuthorizationHandler = handler
}
// SetRefreshingScopeHandler check the scope of the refreshing token
func (s *Server) SetRefreshingScopeHandler(handler RefreshingScopeHandler) {
s.RefreshingScopeHandler = handler
}
// SetRefreshingValidationHandler check if refresh_token is still valid. eg no revocation or other
func (s *Server) SetRefreshingValidationHandler(handler RefreshingValidationHandler) {
s.RefreshingValidationHandler = handler
}
// SetResponseErrorHandler response error handling
func (s *Server) SetResponseErrorHandler(handler ResponseErrorHandler) {
s.ResponseErrorHandler = handler
}
// SetInternalErrorHandler internal error handling
func (s *Server) SetInternalErrorHandler(handler InternalErrorHandler) {
s.InternalErrorHandler = handler
}
// SetExtensionFieldsHandler in response to the access token with the extension of the field
func (s *Server) SetExtensionFieldsHandler(handler ExtensionFieldsHandler) {
s.ExtensionFieldsHandler = handler
}
// SetAccessTokenExpHandler set expiration date for the access token
func (s *Server) SetAccessTokenExpHandler(handler AccessTokenExpHandler) {
s.AccessTokenExpHandler = handler
}
// SetAuthorizeScopeHandler set scope for the access token
func (s *Server) SetAuthorizeScopeHandler(handler AuthorizeScopeHandler) {
s.AuthorizeScopeHandler = handler
}

View File

@ -1,36 +0,0 @@
package oauth2
import "context"
type (
// ClientStore the client information storage interface
ClientStore interface {
GetByID(ctx context.Context, id string) (ClientInfo, error)
Set(ctx context.Context, id string, cli ClientInfo) error
Delete(ctx context.Context, id string) error
}
// TokenStore the token information storage interface
TokenStore interface {
// create and store the new token information
Create(ctx context.Context, info TokenInfo) error
// delete the authorization code
RemoveByCode(ctx context.Context, code string) error
// use the access token to delete the token information
RemoveByAccess(ctx context.Context, access string) error
// use the refresh token to delete the token information
RemoveByRefresh(ctx context.Context, refresh string) error
// use the authorization code for token information data
GetByCode(ctx context.Context, code string) (TokenInfo, error)
// use the access token for token information data
GetByAccess(ctx context.Context, access string) (TokenInfo, error)
// use the refresh token for token information data
GetByRefresh(ctx context.Context, refresh string) (TokenInfo, error)
}
)