Skip to content

Commit

Permalink
Implemented file uploading api.
Browse files Browse the repository at this point in the history
  • Loading branch information
zensh committed Sep 7, 2023
1 parent 0497d59 commit eb4369e
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 97 deletions.
5 changes: 2 additions & 3 deletions config/default.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ webscraper = "http://127.0.0.1:8080"
walletbase = "http://127.0.0.1:8080"

[oss]
bucket = "yiwenai"
bucket = "ywfs"
endpoint = "oss-cn-hangzhou.aliyuncs.com"
access_key_id = ""
access_key_secret = ""
prefix = "dev/cr/"
url_base = "https://cdn.yiwen.pub/"
base_url = "https://fs.yiwen.pub/"

[[recommendations]]
gid = "cil6ehjmps48vprp24f0"
Expand Down
20 changes: 20 additions & 0 deletions src/api/creation.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/yiwen-ai/yiwen-api/src/content"
"github.com/yiwen-ai/yiwen-api/src/logging"
"github.com/yiwen-ai/yiwen-api/src/middleware"
"github.com/yiwen-ai/yiwen-api/src/service"
"github.com/yiwen-ai/yiwen-api/src/util"
)

Expand Down Expand Up @@ -485,6 +486,25 @@ func (a *Creation) UpdateContent(ctx *gear.Context) error {
return ctx.OkSend(bll.SuccessResponse[*bll.CreationOutput]{Result: output})
}

func (a *Creation) UploadFile(ctx *gear.Context) error {
input := &bll.QueryCreation{}
if err := ctx.ParseBody(input); err != nil {
return err
}

creation, err := a.checkWritePermission(ctx, input.GID, input.ID)
if err != nil {
return err
}

if *creation.Status != 0 && *creation.Status != 1 {
return gear.ErrBadRequest.WithMsg("cannot update creation content, status is not 0 or 1")
}

output := a.blls.Writing.SignPostPolicy(creation.GID, creation.ID, *creation.Language, uint(*creation.Version))
return ctx.OkSend(bll.SuccessResponse[service.PostFilePolicy]{Result: output})
}

func (a *Creation) checkReadPermission(ctx *gear.Context, gid util.ID) error {
sess := gear.CtxValue[middleware.Session](ctx)
role, err := a.blls.Userbase.UserGroupRole(ctx, sess.UserID, gid)
Expand Down
21 changes: 20 additions & 1 deletion src/api/publication.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/yiwen-ai/yiwen-api/src/content"
"github.com/yiwen-ai/yiwen-api/src/logging"
"github.com/yiwen-ai/yiwen-api/src/middleware"
"github.com/yiwen-ai/yiwen-api/src/service"
"github.com/yiwen-ai/yiwen-api/src/util"
)

Expand Down Expand Up @@ -700,7 +701,7 @@ func (a *Publication) ListArchived(ctx *gear.Context) error {
}

func (a *Publication) GetPublishList(ctx *gear.Context) error {
input := &bll.QueryAllPublish{}
input := &bll.GidCidInput{}
if err := ctx.ParseURL(input); err != nil {
return err
}
Expand Down Expand Up @@ -907,6 +908,24 @@ func (a *Publication) Collect(ctx *gear.Context) error {
return ctx.OkSend(bll.SuccessResponse[*bll.CollectionOutput]{Result: output})
}

func (a *Publication) UploadFile(ctx *gear.Context) error {
input := &bll.QueryPublication{}
if err := ctx.ParseBody(input); err != nil {
return err
}
publication, err := a.checkWritePermission(ctx, input.GID, input.CID, input.Language, input.Version)
if err != nil {
return err
}

if *publication.Status != 0 {
return gear.ErrBadRequest.WithMsg("cannot update publication content, status is not 0 or 1")
}

output := a.blls.Writing.SignPostPolicy(publication.GID, publication.CID, publication.Language, uint(publication.Version))
return ctx.OkSend(bll.SuccessResponse[service.PostFilePolicy]{Result: output})
}

func (a *Publication) checkReadPermission(ctx *gear.Context, gid util.ID) (int8, error) {
sess := gear.CtxValue[middleware.Session](ctx)
if sess.UserID == util.ANON {
Expand Down
2 changes: 2 additions & 0 deletions src/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func newRouters(apis *APIs) []*gear.Router {
router.Put("/v1/creation/update_content", middleware.AuthToken.Auth, apis.Creation.UpdateContent)
router.Patch("/v1/creation/update_content", middleware.AuthToken.Auth, todo) // 暂不实现
router.Post("/v1/creation/assist", middleware.AuthToken.Auth, todo) // 暂不实现
router.Post("/v1/creation/upload", middleware.AuthToken.Auth, apis.Creation.UploadFile)

router.Post("/v1/publication", middleware.AuthToken.Auth, apis.Publication.Create)
router.Post("/v1/publication/estimate", middleware.AuthToken.Auth, apis.Publication.Estimate)
Expand All @@ -99,6 +100,7 @@ func newRouters(apis *APIs) []*gear.Router {
router.Put("/v1/publication/update_content", middleware.AuthToken.Auth, apis.Publication.UpdateContent)
router.Post("/v1/publication/assist", middleware.AuthToken.Auth, todo) // 暂不实现
router.Post("/v1/publication/collect", middleware.AuthToken.Auth, apis.Publication.Collect)
router.Post("/v1/publication/upload", middleware.AuthToken.Auth, apis.Publication.UploadFile)

router.Patch("/v1/collection", middleware.AuthToken.Auth, apis.Collection.Update)
router.Delete("/v1/collection", middleware.AuthToken.Auth, apis.Collection.Delete)
Expand Down
13 changes: 13 additions & 0 deletions src/bll/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,16 @@ func (i *QueryIdCn) Validate() error {
}
return nil
}

type GidCidInput struct {
GID util.ID `json:"gid" cbor:"gid" query:"gid" validate:"required"`
CID util.ID `json:"cid" cbor:"cid" query:"cid" validate:"required"`
}

func (i *GidCidInput) Validate() error {
if err := util.Validator.Struct(i); err != nil {
return gear.ErrBadRequest.From(err)
}

return nil
}
4 changes: 4 additions & 0 deletions src/bll/writing.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,7 @@ func (b *Writing) OriginalSearch(ctx context.Context, input *ScrapingInput) Sear

return output.Result
}

func (b *Writing) SignPostPolicy(gid, cid util.ID, lang string, version uint) service.PostFilePolicy {
return b.oss.SignPostPolicy(gid.String(), cid.String(), lang, version)
}
17 changes: 2 additions & 15 deletions src/bll/writing_publication.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (i *PublicationOutput) IntoPublicationDraft(gid util.ID, language, model st

func (b *Writing) InitApp(ctx context.Context, _ *gear.App) error {
for _, v := range conf.Config.Recommendations {
res, err := b.GetPublicationList(ctx, 2, &QueryAllPublish{
res, err := b.GetPublicationList(ctx, 2, &GidCidInput{
GID: v.GID,
CID: v.CID,
})
Expand Down Expand Up @@ -381,20 +381,7 @@ func (b *Writing) ListPublicationByGIDs(ctx context.Context, input *GIDsPaginati
return &output, nil
}

type QueryAllPublish struct {
GID util.ID `json:"gid" cbor:"gid" query:"gid" validate:"required"`
CID util.ID `json:"cid" cbor:"cid" query:"cid" validate:"required"`
}

func (i *QueryAllPublish) Validate() error {
if err := util.Validator.Struct(i); err != nil {
return gear.ErrBadRequest.From(err)
}

return nil
}

func (b *Writing) GetPublicationList(ctx context.Context, from_status int8, input *QueryAllPublish) (*SuccessResponse[PublicationOutputs], error) {
func (b *Writing) GetPublicationList(ctx context.Context, from_status int8, input *GidCidInput) (*SuccessResponse[PublicationOutputs], error) {
output := SuccessResponse[PublicationOutputs]{}
query := url.Values{}
query.Add("gid", input.GID.String())
Expand Down
3 changes: 1 addition & 2 deletions src/conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ type OSS struct {
Endpoint string `json:"endpoint" toml:"endpoint"`
AccessKeyId string `json:"access_key_id" toml:"access_key_id"`
AccessKeySecret string `json:"access_key_secret" toml:"access_key_secret"`
Prefix string `json:"prefix" toml:"prefix"`
UrlBase string `json:"url_base" toml:"url_base"`
BaseUrl string `json:"url_base" toml:"base_url"`
}

type Recommendation struct {
Expand Down
133 changes: 61 additions & 72 deletions src/service/oss.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package service

import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
"strings"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"fmt"
"time"

"github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/teambition/gear"

"github.com/yiwen-ai/yiwen-api/src/conf"
"github.com/yiwen-ai/yiwen-api/src/util"
Expand All @@ -22,9 +19,8 @@ func init() {
}

type OSS struct {
UrlBase string
Prefix string
bucket *oss.Bucket
cfg conf.OSS
bucket *oss.Bucket
}

func NewOSS() *OSS {
Expand All @@ -39,73 +35,66 @@ func NewOSS() *OSS {
}

return &OSS{
UrlBase: cfg.UrlBase,
Prefix: cfg.Prefix,
bucket: bucket,
cfg: cfg,
bucket: bucket,
}
}

func (s *OSS) SavePicture(ctx context.Context, imgPath, imgUrl string) (string, error) {
ctype, reader, err := GetPicture(ctx, imgUrl)
if err != nil {
return "", err
}

objectKey := s.Prefix + imgPath
if err := s.bucket.PutObject(objectKey, reader, oss.ContentType(ctype),
oss.CacheControl("public"), oss.ContentDisposition("inline")); err != nil {
return "", err
}

return s.UrlBase + objectKey, nil
// https://help.aliyun.com/zh/oss/developer-reference/postobject
// 如 base_url 为 https://fs.yiwen.pub/grv5...9pjc/cil6...4f0/1/zho/
// 上传了文件 yiwen.ai.png,
// 则该文件访问链接为 https://fs.yiwen.pub/grv5...9pjc/cil6...4f0/1/zho/yiwen.ai.png
type PostFilePolicy struct {
Host string `json:"host" cbor:"host"`
Dir string `json:"dir" cbor:"dir"`
AccessKey string `json:"access_key" cbor:"access_key"`
Policy string `json:"policy" cbor:"policy"`
Signature string `json:"signature" cbor:"signature"`
BaseUrl string `json:"base_url" cbor:"base_url"`
}

func GetPicture(ctx context.Context, imgUrl string) (string, io.ReadCloser, error) {
req, err := http.NewRequestWithContext(ctx, "GET", imgUrl, nil)
if err != nil {
return "", nil, err
// 指定过期时间,单位为秒。
const ossExpiration = 3600 * time.Second
const ossMinContentLength = 1024
const ossMaxContentLength = 1024 * 1024 * 10
const ossCacheControl = "public, max-age=604800, immutable"
const ossContentDisposition = "inline"

var ossContentType = []string{"image/jpg", "image/png", "image/gif", "image/jpeg", "image/webp"}

// https://help.aliyun.com/zh/oss/use-cases/client-direct-transmission-overview
func (s *OSS) SignPostPolicy(gid, cid, lang string, version uint) PostFilePolicy {
expiration := time.Now().Add(ossExpiration).UTC().Format("2006-01-02T15:04:05.999Z")
// https://help.aliyun.com/zh/oss/use-cases/oss-performance-and-scalability-best-practices
// 反转打散分区,避免热点
dir := fmt.Sprintf("%s/%s/%d/%s/", util.Reverse(cid), gid, version, lang)

data, _ := json.Marshal(map[string]any{
"expiration": expiration,
"conditions": []any{
// map[string]string{"bucket": "ywfs"},
[]any{"content-length-range", ossMinContentLength, ossMaxContentLength},
[]any{"starts-with", "$key", dir},
[]any{"in", "$content-type", ossContentType},
[]any{"eq", "$cache-control", ossCacheControl},
[]any{"eq", "$content-disposition", ossContentDisposition},
},
})

policy := base64.StdEncoding.EncodeToString(data)
hm := hmac.New(sha1.New, []byte(s.cfg.AccessKeySecret))
hm.Write([]byte(policy))
pp := PostFilePolicy{
Host: fmt.Sprintf("https://%s.%s", s.cfg.Bucket, s.cfg.Endpoint),
Dir: dir,
AccessKey: s.cfg.AccessKeyId,
Policy: policy,
Signature: base64.StdEncoding.EncodeToString(hm.Sum(nil)),
BaseUrl: s.cfg.BaseUrl + dir,
}

resp, err := fileHTTPClient.Do(req)
if err != nil {
if err.(*url.Error).Unwrap() == context.Canceled {
return "", nil, gear.ErrClientClosedRequest
}

return "", nil, err
}

if resp.StatusCode != http.StatusOK {
data, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return "", nil, gear.Err.WithCode(resp.StatusCode).WithMsg(string(data))
}

ct := strings.ToLower(resp.Header.Get(gear.HeaderContentType))
if !strings.Contains(ct, "image") {
resp.Body.Close()
return "", nil, gear.ErrUnsupportedMediaType.WithMsg(ct)
}

return ct, resp.Body, nil
}

var tr = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: false},
DialContext: (&net.Dialer{
Timeout: 10 * time.Second,
KeepAlive: 15 * time.Second,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 20,
IdleConnTimeout: 25 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 10 * time.Second,
ResponseHeaderTimeout: 15 * time.Second,
return pp
}

var fileHTTPClient = &http.Client{
Transport: tr,
Timeout: time.Second * 60,
func (s *OSS) ListObjects(cid string) (any, error) {
return s.bucket.ListObjectsV2(oss.Prefix(fmt.Sprintf("%s/", util.Reverse(cid))))
}
13 changes: 13 additions & 0 deletions src/util/common.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package util

import "unicode/utf8"

func SliceHas[T comparable](sl []T, v T) bool {
for _, s := range sl {
if v == s {
Expand All @@ -22,3 +24,14 @@ func RemoveDuplicates[T comparable](sl []T) []T {
}
return res
}

func Reverse(s string) string {
size := len(s)
buf := make([]byte, size)
for start := 0; start < size; {
r, n := utf8.DecodeRuneInString(s[start:])
start += n
utf8.EncodeRune(buf[size-start:], r)
}
return string(buf)
}
6 changes: 6 additions & 0 deletions src/util/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ func TestRemoveDuplicates(t *testing.T) {
id2 := mustParseID(id.String())
assert.Equal(t, RemoveDuplicates([]ID{id, id2}), []ID{id})
}

func TestReverse(t *testing.T) {
for _, s := range []string{"Hello, 世界", Ptr(NewID()).String()} {
assert.Equal(t, s, Reverse(Reverse(s)))
}
}
11 changes: 7 additions & 4 deletions src/util/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ import (
var onceTK sync.Once
var tk *tiktoken.Tiktoken
var tokensRate = map[string]float32{
"eng": 1.0,
"zho": 1.29,
"jpn": 1.88,
"eng": 1.00,
"zho": 1.20,
"jpn": 1.65,
"fra": 1.31,
"kor": 1.57,
"ara": 2.10,
}

const MAX_TOKENS = 64 * 1024 // 64k
const MAX_TOKENS = 128 * 1024

func init() {
onceTK.Do(func() {
Expand Down

0 comments on commit eb4369e

Please sign in to comment.